lightningnetwork / lnd / 20917664041

12 Jan 2026 11:26AM UTC coverage: 65.133% (+0.06%) from 65.074%

Pull Request #10414: [g175] graph/db: merge g175 types-prep side branch
Merge 343104337 into 39a1421d1

449 of 836 new or added lines in 17 files covered (53.71%).

101 existing lines in 23 files now uncovered.

138153 of 212110 relevant lines covered (65.13%).

20715.84 hits per line.

Source File: /graph/db/sql_store.go (coverage: 0.0%)
package graphdb

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	color "image/color"
	"iter"
	"maps"
	"math"
	"net"
	"slices"
	"strconv"
	"sync"
	"time"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/aliasmgr"
	"github.com/lightningnetwork/lnd/batch"
	"github.com/lightningnetwork/lnd/fn/v2"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/sqldb"
	"github.com/lightningnetwork/lnd/sqldb/sqlc"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/lightningnetwork/lnd/tor"
)

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error
	NodeExists(ctx context.Context, arg sqlc.NodeExistsParams) (bool, error)

	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
	GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

	/*
		Migration specific queries.

		NOTE: these should not be used in code other than migrations.
		Once sqldbv2 is in place, these can be removed from this struct
		as then migrations will have their own dedicated queries
		structs.
	*/
	InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
	InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
	InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
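
// Illustrative sketch (added for this review, not part of the upstream file):
// BatchedSQLQueries is what the methods below use to run a closure inside a
// single database transaction, for example:
//
//	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
//		_, err := db.GetNodeIDByPubKey(ctx, sqlc.GetNodeIDByPubKeyParams{
//			Version: int16(lnwire.GossipVersion1),
//			PubKey:  pubKey[:],
//		})
//
//		return err
//	}, sqldb.NoOpReset)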

// SQLStore is an implementation of the Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If this
	// mutex is to be acquired at the same time as the DB mutex, then
	// cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the Store
// interface.
var _ Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryCfg holds configuration values for SQL queries.
	QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
	options ...StoreOptionModifier) (*SQLStore, error) {

	opts := DefaultOptions()
	for _, o := range options {
		o(opts)
	}

	if opts.NoMigration {
		return nil, fmt.Errorf("the NoMigration option is not yet " +
			"supported for SQL stores")
	}

	s := &SQLStore{
		cfg:         cfg,
		db:          db,
		rejectCache: newRejectCache(opts.RejectCacheSize),
		chanCache:   newChannelCache(opts.ChannelCacheSize),
		srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
	}

	s.chanScheduler = batch.NewTimeScheduler(
		db, &s.cacheMu, opts.BatchCommitInterval,
	)
	s.nodeScheduler = batch.NewTimeScheduler(
		db, nil, opts.BatchCommitInterval,
	)

	return s, nil
}
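
// Illustrative usage sketch (added for this review, not part of the upstream
// file): given a BatchedSQLQueries implementation `db`, a *sqldb.QueryConfig
// `queryCfg` and the target chain's genesis hash `chainHash` (all assumed to
// exist in the caller), a store can be constructed as:
//
//	store, err := NewSQLStore(&SQLStoreConfig{
//		ChainHash: chainHash,
//		QueryCfg:  queryCfg,
//	}, db)
//	if err != nil {
//		// Handle construction failure.
//	}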

// AddNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the Store interface.
func (s *SQLStore) AddNode(ctx context.Context,
	node *models.Node, opts ...batch.SchedulerOption) error {

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Do: func(queries SQLQueries) error {
			_, err := upsertNode(ctx, queries, node)

			// It is possible that two of the same node
			// announcements are both being processed in the same
			// batch. This may cause the UpsertNode conflict to
			// be hit since we require at the db layer that the
			// new last_update is greater than the existing
			// last_update. We need to gracefully handle this here.
			if errors.Is(err, sql.ErrNoRows) {
				return nil
			}

			return err
		},
	}

	return s.nodeScheduler.Execute(ctx, r)
}
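
// Illustrative sketch (added for this review, not part of the upstream file):
// with an existing *SQLStore `store` and a *models.Node `node` built from a
// node announcement (both assumed), a caller simply does:
//
//	if err := store.AddNode(ctx, node); err != nil {
//		// Handle the upsert or batch scheduling error.
//	}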

// FetchNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the Store interface.
func (s *SQLStore) FetchNode(ctx context.Context, v lnwire.GossipVersion,
	pubKey route.Vertex) (*models.Node, error) {

	var node *models.Node
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		_, node, err = getNodeByPubKey(
			ctx, s.cfg.QueryCfg, db, v, pubKey,
		)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	return node, nil
}

// HasV1Node determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the Store interface.
func (s *SQLStore) HasV1Node(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		exists     bool
		lastUpdate time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  pubKey[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		exists = true

		if dbNode.LastUpdate.Valid {
			lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return lastUpdate, exists, nil
}

// HasNode determines if the graph has a vertex identified by the
// target node identity public key.
//
// NOTE: part of the Store interface.
func (s *SQLStore) HasNode(ctx context.Context, v lnwire.GossipVersion,
	pubKey [33]byte) (bool, error) {

	var exists bool
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		exists, err = db.NodeExists(ctx, sqlc.NodeExistsParams{
			Version: int16(v),
			PubKey:  pubKey[:],
		})

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to check if node (%x) "+
			"exists: %w", pubKey, err)
	}

	return exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates whether the
// given node is known to the graph DB.
//
// NOTE: part of the Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		addresses []net.Addr
		known     bool
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, check if the node exists and get its DB ID if it
		// does.
		dbID, err := db.GetNodeIDByPubKey(
			ctx, sqlc.GetNodeIDByPubKeyParams{
				Version: int16(v),
				PubKey:  nodePub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}

		known = true

		addresses, err = getNodeAddresses(ctx, db, dbID)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}

// DeleteNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the Store interface.
func (s *SQLStore) DeleteNode(ctx context.Context, v lnwire.GossipVersion,
	pubKey route.Vertex) error {

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteNodeByPubKey(
			ctx, sqlc.DeleteNodeByPubKeyParams{
				Version: int16(v),
				PubKey:  pubKey[:],
			},
		)
		if err != nil {
			return err
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrGraphNodeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d rows, expected 1", rows)
		}

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to delete node: %w", err)
	}

	return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(v lnwire.GossipVersion,
	nodePub route.Vertex) (*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	return fetchNodeFeatures(ctx, s.db, v, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when both of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	var (
		ctx     = context.TODO()
		chanIDs []uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context, v lnwire.GossipVersion,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(v),
				PubKey:  pub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrNodeAliasNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		if !dbNode.Alias.Valid {
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the Store interface.
func (s *SQLStore) SourceNode(ctx context.Context,
	v lnwire.GossipVersion) (*models.Node, error) {

	var node *models.Node
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		_, nodePub, err := s.getSourceNode(ctx, db, v)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		_, node, err = getNodeByPubKey(
			ctx, s.cfg.QueryCfg, db, v, nodePub,
		)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
	node *models.Node) error {

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// For the source node, we use a less strict upsert that allows
		// updates even when the timestamp hasn't changed. This handles
		// the race condition where multiple goroutines (e.g.,
		// setSelfNode, createNewHiddenService, RPC updates) read the
		// same old timestamp, independently increment it, and try to
		// write concurrently. We want all parameter changes to persist,
		// even if timestamps collide.
		id, err := upsertSourceNode(ctx, db, node)
		if err != nil {
			return fmt.Errorf("unable to upsert source node: %w",
				err)
		}

		// Make sure that if a source node for this version is already
		// set, then the ID is the same as the one we are about to set.
		dbSourceNodeID, _, err := s.getSourceNode(
			ctx, db, node.Version,
		)
		if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		} else if err == nil {
			if dbSourceNodeID != id {
				return fmt.Errorf("v1 source node already "+
					"set to a different node: %d vs %d",
					dbSourceNodeID, id)
			}

			return nil
		}

		return db.AddSourceNode(ctx, id)
	}, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
//
// NOTE: This is part of the Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
	opts ...IteratorOption) iter.Seq2[*models.Node, error] {

	cfg := defaultIteratorConfig()
	for _, opt := range opts {
		opt(cfg)
	}

	return func(yield func(*models.Node, error) bool) {
		var (
			ctx            = context.TODO()
			lastUpdateTime sql.NullInt64
			lastPubKey     = make([]byte, 33)
			hasMore        = true
		)

		// Each iteration, we'll read a batch amount of nodes, yield
		// them, then decide if we have more or not.
		for hasMore {
			var batch []*models.Node

			//nolint:ll
			err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
				//nolint:ll
				params := sqlc.GetNodesByLastUpdateRangeParams{
					StartTime: sqldb.SQLInt64(
						startTime.Unix(),
					),
					EndTime: sqldb.SQLInt64(
						endTime.Unix(),
					),
					LastUpdate: lastUpdateTime,
					LastPubKey: lastPubKey,
					OnlyPublic: sql.NullBool{
						Bool:  cfg.iterPublicNodes,
						Valid: true,
					},
					MaxResults: sqldb.SQLInt32(
						cfg.nodeUpdateIterBatchSize,
					),
				}
				rows, err := db.GetNodesByLastUpdateRange(
					ctx, params,
				)
				if err != nil {
					return err
				}

				hasMore = len(rows) == cfg.nodeUpdateIterBatchSize

				err = forEachNodeInBatch(
					ctx, s.cfg.QueryCfg, db, rows,
					func(_ int64, node *models.Node) error {
						batch = append(batch, node)

						// Update pagination cursors
						// based on the last processed
						// node.
						lastUpdateTime = sql.NullInt64{
							Int64: node.LastUpdate.
								Unix(),
							Valid: true,
						}
						lastPubKey = node.PubKeyBytes[:]

						return nil
					},
				)
				if err != nil {
					return fmt.Errorf("unable to build "+
						"nodes: %w", err)
				}

				return nil
			}, func() {
				batch = []*models.Node{}
			})

			if err != nil {
				log.Errorf("NodeUpdatesInHorizon batch "+
					"error: %v", err)

				yield(&models.Node{}, err)

				return
			}

			for _, node := range batch {
				if !yield(node, nil) {
					return
				}
			}

			// If the batch didn't yield anything, then we're done.
			if len(batch) == 0 {
				break
			}
		}
	}
}
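
// Illustrative sketch (added for this review, not part of the upstream file):
// since NodeUpdatesInHorizon returns an iter.Seq2, callers can consume it
// with a range-over-func loop, where `store`, `start` and `end` are assumed:
//
//	for node, err := range store.NodeUpdatesInHorizon(start, end) {
//		if err != nil {
//			// Handle the batch read error.
//			break
//		}
//
//		// Process the node announcement.
//		_ = node
//	}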
694

695
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
696
// undirected edge from the two target nodes are created. The information stored
697
// denotes the static attributes of the channel, such as the channelID, the keys
698
// involved in creation of the channel, and the set of features that the channel
699
// supports. The chanPoint and chanID are used to uniquely identify the edge
700
// globally within the database.
701
//
702
// NOTE: part of the Store interface.
703
func (s *SQLStore) AddChannelEdge(ctx context.Context,
704
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
705

×
706
        var alreadyExists bool
×
707
        r := &batch.Request[SQLQueries]{
×
708
                Opts: batch.NewSchedulerOptions(opts...),
×
709
                Reset: func() {
×
710
                        alreadyExists = false
×
711
                },
×
712
                Do: func(tx SQLQueries) error {
×
713
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
714

×
715
                        // Make sure that the channel doesn't already exist. We
×
716
                        // do this explicitly instead of relying on catching a
×
717
                        // unique constraint error because relying on SQL to
×
718
                        // throw that error would abort the entire batch of
×
719
                        // transactions.
×
720
                        _, err := tx.GetChannelBySCID(
×
721
                                ctx, sqlc.GetChannelBySCIDParams{
×
722
                                        Scid:    chanIDB,
×
723
                                        Version: int16(lnwire.GossipVersion1),
×
724
                                },
×
725
                        )
×
726
                        if err == nil {
×
727
                                alreadyExists = true
×
728
                                return nil
×
729
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
730
                                return fmt.Errorf("unable to fetch channel: %w",
×
731
                                        err)
×
732
                        }
×
733

734
                        return insertChannel(ctx, tx, edge)
×
735
                },
736
                OnCommit: func(err error) error {
×
737
                        switch {
×
738
                        case err != nil:
×
739
                                return err
×
740
                        case alreadyExists:
×
741
                                return ErrEdgeAlreadyExist
×
742
                        default:
×
743
                                s.rejectCache.remove(edge.ChannelID)
×
744
                                s.chanCache.remove(edge.ChannelID)
×
745
                                return nil
×
746
                        }
747
                },
748
        }
749

750
        return s.chanScheduler.Execute(ctx, r)
×
751
}
752

753
// HighestChanID returns the "highest" known channel ID in the channel graph.
754
// This represents the "newest" channel from the PoV of the chain. This method
755
// can be used by peers to quickly determine if their graphs are in sync.
756
//
757
// NOTE: This is part of the Store interface.
758
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
759
        var highestChanID uint64
×
760
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
761
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
762
                if errors.Is(err, sql.ErrNoRows) {
×
763
                        return nil
×
764
                } else if err != nil {
×
765
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
766
                                err)
×
767
                }
×
768

769
                highestChanID = byteOrder.Uint64(chanID)
×
770

×
771
                return nil
×
772
        }, sqldb.NoOpReset)
773
        if err != nil {
×
774
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
775
        }
×
776

777
        return highestChanID, nil
×
778
}
779

780
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
781
// within the database for the referenced channel. The `flags` attribute within
782
// the ChannelEdgePolicy determines which of the directed edges are being
783
// updated. If the flag is 1, then the first node's information is being
784
// updated, otherwise it's the second node's information. The node ordering is
785
// determined by the lexicographical ordering of the identity public keys of the
786
// nodes on either side of the channel.
787
//
788
// NOTE: part of the Store interface.
789
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
790
        edge *models.ChannelEdgePolicy,
791
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
792

×
793
        var (
×
794
                isUpdate1    bool
×
795
                edgeNotFound bool
×
796
                from, to     route.Vertex
×
797
        )
×
798

×
799
        r := &batch.Request[SQLQueries]{
×
800
                Opts: batch.NewSchedulerOptions(opts...),
×
801
                Reset: func() {
×
802
                        isUpdate1 = false
×
803
                        edgeNotFound = false
×
804
                },
×
805
                Do: func(tx SQLQueries) error {
×
806
                        var err error
×
807
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
808
                                ctx, tx, edge,
×
809
                        )
×
810
                        // It is possible that two of the same policy
×
811
                        // announcements are both being processed in the same
×
812
                        // batch. This may case the UpsertEdgePolicy conflict to
×
813
                        // be hit since we require at the db layer that the
×
814
                        // new last_update is greater than the existing
×
815
                        // last_update. We need to gracefully handle this here.
×
816
                        if errors.Is(err, sql.ErrNoRows) {
×
817
                                return nil
×
818
                        } else if err != nil {
×
819
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
820
                        }
×
821

822
                        // Silence ErrEdgeNotFound so that the batch can
823
                        // succeed, but propagate the error via local state.
824
                        if errors.Is(err, ErrEdgeNotFound) {
×
825
                                edgeNotFound = true
×
826
                                return nil
×
827
                        }
×
828

829
                        return err
×
830
                },
831
                OnCommit: func(err error) error {
×
832
                        switch {
×
833
                        case err != nil:
×
834
                                return err
×
835
                        case edgeNotFound:
×
836
                                return ErrEdgeNotFound
×
837
                        default:
×
838
                                s.updateEdgeCache(edge, isUpdate1)
×
839
                                return nil
×
840
                        }
841
                },
842
        }
843

844
        err := s.chanScheduler.Execute(ctx, r)
×
845

×
846
        return from, to, err
×
847
}
848

849
// updateEdgeCache updates our reject and channel caches with the new
850
// edge policy information.
851
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
852
        isUpdate1 bool) {
×
853

×
854
        // If an entry for this channel is found in reject cache, we'll modify
×
855
        // the entry with the updated timestamp for the direction that was just
×
856
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
857
        // during the next query for this edge.
×
858
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
859
                if isUpdate1 {
×
860
                        entry.upd1Time = e.LastUpdate.Unix()
×
861
                } else {
×
862
                        entry.upd2Time = e.LastUpdate.Unix()
×
863
                }
×
864
                s.rejectCache.insert(e.ChannelID, entry)
×
865
        }
866

867
        // If an entry for this channel is found in channel cache, we'll modify
868
        // the entry with the updated policy for the direction that was just
869
        // written. If the edge doesn't exist, we'll defer loading the info and
870
        // policies and lazily read from disk during the next query.
871
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
872
                if isUpdate1 {
×
873
                        channel.Policy1 = e
×
874
                } else {
×
875
                        channel.Policy2 = e
×
876
                }
×
877
                s.chanCache.insert(e.ChannelID, channel)
×
878
        }
879
}
880

881
// ForEachSourceNodeChannel iterates through all channels of the source node,
882
// executing the passed callback on each. The call-back is provided with the
883
// channel's outpoint, whether we have a policy for the channel and the channel
884
// peer's node information.
885
//
886
// NOTE: part of the Store interface.
887
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
888
        cb func(chanPoint wire.OutPoint, havePolicy bool,
889
                otherNode *models.Node) error, reset func()) error {
×
890

×
891
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
892
                nodeID, nodePub, err := s.getSourceNode(
×
893
                        ctx, db, lnwire.GossipVersion1,
×
894
                )
×
895
                if err != nil {
×
896
                        return fmt.Errorf("unable to fetch source node: %w",
×
897
                                err)
×
898
                }
×
899

900
                return forEachNodeChannel(
×
901
                        ctx, db, s.cfg, nodeID,
×
902
                        func(info *models.ChannelEdgeInfo,
×
903
                                outPolicy *models.ChannelEdgePolicy,
×
904
                                _ *models.ChannelEdgePolicy) error {
×
905

×
906
                                // Fetch the other node.
×
907
                                var (
×
908
                                        otherNodePub [33]byte
×
909
                                        node1        = info.NodeKey1Bytes
×
910
                                        node2        = info.NodeKey2Bytes
×
911
                                )
×
912
                                switch {
×
913
                                case bytes.Equal(node1[:], nodePub[:]):
×
914
                                        otherNodePub = node2
×
915
                                case bytes.Equal(node2[:], nodePub[:]):
×
916
                                        otherNodePub = node1
×
917
                                default:
×
918
                                        return fmt.Errorf("node not " +
×
919
                                                "participating in this channel")
×
920
                                }
921

922
                                _, otherNode, err := getNodeByPubKey(
×
NEW
923
                                        ctx, s.cfg.QueryCfg, db,
×
NEW
924
                                        lnwire.GossipVersion1, otherNodePub,
×
925
                                )
×
926
                                if err != nil {
×
927
                                        return fmt.Errorf("unable to fetch "+
×
928
                                                "other node(%x): %w",
×
929
                                                otherNodePub, err)
×
930
                                }
×
931

932
                                return cb(
×
933
                                        info.ChannelPoint, outPolicy != nil,
×
934
                                        otherNode,
×
935
                                )
×
936
                        },
937
                )
938
        }, reset)
939
}
940

941
// ForEachNode iterates through all the stored vertices/nodes in the graph,
942
// executing the passed callback with each node encountered. If the callback
943
// returns an error, then the transaction is aborted and the iteration stops
944
// early.
945
//
946
// NOTE: part of the Store interface.
947
func (s *SQLStore) ForEachNode(ctx context.Context,
948
        cb func(node *models.Node) error, reset func()) error {
×
949

×
950
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
951
                return forEachNodePaginated(
×
952
                        ctx, s.cfg.QueryCfg, db,
×
953
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
954
                                node *models.Node) error {
×
955

×
956
                                return cb(node)
×
957
                        },
×
958
                )
959
        }, reset)
960
}
961

962
// ForEachNodeDirectedChannel iterates through all channels of a given node,
963
// executing the passed callback on the directed edge representing the channel
964
// and its incoming policy. If the callback returns an error, then the iteration
965
// is halted with the error propagated back up to the caller.
966
//
967
// Unknown policies are passed into the callback as nil values.
968
//
969
// NOTE: this is part of the graphdb.NodeTraverser interface.
970
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
971
        cb func(channel *DirectedChannel) error, reset func()) error {
×
972

×
973
        var ctx = context.TODO()
×
974

×
975
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
976
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
977
        }, reset)
×
978
}
979

980
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
981
// graph, executing the passed callback with each node encountered. If the
982
// callback returns an error, then the transaction is aborted and the iteration
983
// stops early.
984
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
985
        cb func(route.Vertex, *lnwire.FeatureVector) error,
986
        reset func()) error {
×
987

×
988
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
989
                return forEachNodeCacheable(
×
990
                        ctx, s.cfg.QueryCfg, db,
×
991
                        func(_ int64, nodePub route.Vertex,
×
992
                                features *lnwire.FeatureVector) error {
×
993

×
994
                                return cb(nodePub, features)
×
995
                        },
×
996
                )
997
        }, reset)
998
        if err != nil {
×
999
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
1000
        }
×
1001

1002
        return nil
×
1003
}
1004

1005
// ForEachNodeChannel iterates through all channels of the given node,
1006
// executing the passed callback with an edge info structure and the policies
1007
// of each end of the channel. The first edge policy is the outgoing edge *to*
1008
// the connecting node, while the second is the incoming edge *from* the
1009
// connecting node. If the callback returns an error, then the iteration is
1010
// halted with the error propagated back up to the caller.
1011
//
1012
// Unknown policies are passed into the callback as nil values.
1013
//
1014
// NOTE: part of the Store interface.
1015
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
1016
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1017
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1018

×
1019
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1020
                dbNode, err := db.GetNodeByPubKey(
×
1021
                        ctx, sqlc.GetNodeByPubKeyParams{
×
1022
                                Version: int16(lnwire.GossipVersion1),
×
1023
                                PubKey:  nodePub[:],
×
1024
                        },
×
1025
                )
×
1026
                if errors.Is(err, sql.ErrNoRows) {
×
1027
                        return nil
×
1028
                } else if err != nil {
×
1029
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1030
                }
×
1031

1032
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1033
        }, reset)
1034
}
1035

1036
// extractMaxUpdateTime returns the maximum of the two policy update times.
1037
// This is used for pagination cursor tracking.
1038
func extractMaxUpdateTime(
1039
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1040

×
1041
        switch {
×
1042
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1043
                return max(row.Policy1LastUpdate.Int64,
×
1044
                        row.Policy2LastUpdate.Int64)
×
1045
        case row.Policy1LastUpdate.Valid:
×
1046
                return row.Policy1LastUpdate.Int64
×
1047
        case row.Policy2LastUpdate.Valid:
×
1048
                return row.Policy2LastUpdate.Int64
×
1049
        default:
×
1050
                return 0
×
1051
        }
1052
}
1053

1054
// buildChannelFromRow constructs a ChannelEdge from a database row.
1055
// This includes building the nodes, channel info, and policies.
1056
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1057
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1058

×
1059
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1060
        if err != nil {
×
1061
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1062
                        err)
×
1063
        }
×
1064

1065
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1066
        if err != nil {
×
1067
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1068
                        err)
×
1069
        }
×
1070

1071
        channel, err := getAndBuildEdgeInfo(
×
1072
                ctx, s.cfg, db,
×
1073
                row.GraphChannel, node1.PubKeyBytes,
×
1074
                node2.PubKeyBytes,
×
1075
        )
×
1076
        if err != nil {
×
1077
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1078
                        "channel info: %w", err)
×
1079
        }
×
1080

1081
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1082
        if err != nil {
×
1083
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1084
                        "channel policies: %w", err)
×
1085
        }
×
1086

1087
        p1, p2, err := getAndBuildChanPolicies(
×
1088
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1089
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1090
        )
×
1091
        if err != nil {
×
1092
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1093
                        "channel policies: %w", err)
×
1094
        }
×
1095

1096
        return ChannelEdge{
×
1097
                Info:    channel,
×
1098
                Policy1: p1,
×
1099
                Policy2: p2,
×
1100
                Node1:   node1,
×
1101
                Node2:   node2,
×
1102
        }, nil
×
1103
}
1104

1105
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1106
// This method acquires the cache lock only once for the entire batch.
1107
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1108
        if len(edgesToCache) == 0 {
×
1109
                return
×
1110
        }
×
1111

1112
        s.cacheMu.Lock()
×
1113
        defer s.cacheMu.Unlock()
×
1114

×
1115
        for chanID, edge := range edgesToCache {
×
1116
                s.chanCache.insert(chanID, edge)
×
1117
        }
×
1118
}
1119

1120

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Iterator Lifecycle:
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
// 2. Query batch of channels with policies in time range
// 3. For each channel: check if seen, check cache, or build from DB
// 4. Yield channels to caller
// 5. Update cache after successful batch
// 6. Repeat with updated pagination cursor until no more results
//
// NOTE: This is part of the Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {

        // Apply options.
        cfg := defaultIteratorConfig()
        for _, opt := range opts {
                opt(cfg)
        }

        return func(yield func(ChannelEdge, error) bool) {
                var (
                        ctx            = context.TODO()
                        edgesSeen      = make(map[uint64]struct{})
                        edgesToCache   = make(map[uint64]ChannelEdge)
                        hits           int
                        total          int
                        lastUpdateTime sql.NullInt64
                        lastID         sql.NullInt64
                        hasMore        = true
                )

                // Each iteration, we'll read a batch amount of channel updates
                // (consulting the cache along the way), yield them, then loop
                // back to decide if we have any more updates to read out.
                for hasMore {
                        var batch []ChannelEdge

                        // Acquire read lock before starting transaction to
                        // ensure consistent lock ordering (cacheMu -> DB) and
                        // prevent deadlock with write operations.
                        s.cacheMu.RLock()

                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
                                func(db SQLQueries) error {
                                        //nolint:ll
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                                Version: int16(lnwire.GossipVersion1),
                                                StartTime: sqldb.SQLInt64(
                                                        startTime.Unix(),
                                                ),
                                                EndTime: sqldb.SQLInt64(
                                                        endTime.Unix(),
                                                ),
                                                LastUpdateTime: lastUpdateTime,
                                                LastID:         lastID,
                                                MaxResults: sql.NullInt32{
                                                        Int32: int32(
                                                                cfg.chanUpdateIterBatchSize,
                                                        ),
                                                        Valid: true,
                                                },
                                        }
                                        //nolint:ll
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
                                                ctx, params,
                                        )
                                        if err != nil {
                                                return err
                                        }

                                        //nolint:ll
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize

                                        //nolint:ll
                                        for _, row := range rows {
                                                lastUpdateTime = sql.NullInt64{
                                                        Int64: extractMaxUpdateTime(row),
                                                        Valid: true,
                                                }
                                                lastID = sql.NullInt64{
                                                        Int64: row.GraphChannel.ID,
                                                        Valid: true,
                                                }

                                                // Skip if we've already
                                                // processed this channel.
                                                chanIDInt := byteOrder.Uint64(
                                                        row.GraphChannel.Scid,
                                                )
                                                _, ok := edgesSeen[chanIDInt]
                                                if ok {
                                                        continue
                                                }

                                                // Check cache (we already hold
                                                // shared read lock).
                                                channel, ok := s.chanCache.get(
                                                        chanIDInt,
                                                )
                                                if ok {
                                                        hits++
                                                        total++
                                                        edgesSeen[chanIDInt] = struct{}{}
                                                        batch = append(batch, channel)

                                                        continue
                                                }

                                                chanEdge, err := s.buildChannelFromRow(
                                                        ctx, db, row,
                                                )
                                                if err != nil {
                                                        return err
                                                }

                                                edgesSeen[chanIDInt] = struct{}{}
                                                edgesToCache[chanIDInt] = chanEdge

                                                batch = append(batch, chanEdge)

                                                total++
                                        }

                                        return nil
                                }, func() {
                                        batch = nil
                                        edgesSeen = make(map[uint64]struct{})
                                        edgesToCache = make(
                                                map[uint64]ChannelEdge,
                                        )
                                })

                        // Release read lock after transaction completes.
                        s.cacheMu.RUnlock()

                        if err != nil {
                                log.Errorf("ChanUpdatesInHorizon "+
                                        "batch error: %v", err)

                                yield(ChannelEdge{}, err)

                                return
                        }

                        for _, edge := range batch {
                                if !yield(edge, nil) {
                                        return
                                }
                        }

                        // Update cache after successful batch yield, setting
                        // the cache lock only once for the entire batch.
                        s.updateChanCacheBatch(edgesToCache)
                        edgesToCache = make(map[uint64]ChannelEdge)

                        // If the batch didn't yield anything, then we're done.
                        if len(batch) == 0 {
                                break
                        }
                }

                if total > 0 {
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
                                "%.2f (%d/%d)",
                                float64(hits)*100/float64(total), hits, total)
                } else {
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
                                "in horizon (%s, %s)", startTime, endTime)
                }
        }
}
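
// Usage sketch (illustrative only): because ChanUpdatesInHorizon returns an
// iter.Seq2, callers can range over it directly and stop early by breaking
// out of the loop. The start/end values below are assumptions of the sketch.
//
//        start, end := time.Now().Add(-time.Hour), time.Now()
//        for edge, err := range s.ChanUpdatesInHorizon(start, end) {
//                if err != nil {
//                        return err
//                }
//                // Use edge.Info, edge.Policy1 and edge.Policy2 here.
//        }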

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
                chans map[uint64]*DirectedChannel) error, reset func()) error {

        type nodeCachedBatchData struct {
                features      map[int64][]int
                addrs         map[int64][]nodeAddress
                chanBatchData *batchChannelData
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // pageQueryFunc is used to query the next page of nodes.
                pageQueryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                        return db.ListNodeIDsAndPubKeys(
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                        Version: int16(lnwire.GossipVersion1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                // batchDataFunc is then used to batch load the data required
                // for each page of nodes.
                batchDataFunc := func(ctx context.Context,
                        nodeIDs []int64) (*nodeCachedBatchData, error) {

                        // Batch load node features.
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch load "+
                                        "node features: %w", err)
                        }

                        // Maybe fetch the node's addresses if requested.
                        var nodeAddrs map[int64][]nodeAddress
                        if withAddrs {
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
                                )
                                if err != nil {
                                        return nil, fmt.Errorf("unable to "+
                                                "batch load node "+
                                                "addresses: %w", err)
                                }
                        }

                        // Batch load ALL unique channels for ALL nodes in this
                        // page.
                        allChannels, err := db.ListChannelsForNodeIDs(
                                ctx, sqlc.ListChannelsForNodeIDsParams{
                                        Version:  int16(lnwire.GossipVersion1),
                                        Node1Ids: nodeIDs,
                                        Node2Ids: nodeIDs,
                                },
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "fetch channels for nodes: %w", err)
                        }

                        // Deduplicate channels and collect IDs.
                        var (
                                allChannelIDs []int64
                                allPolicyIDs  []int64
                        )
                        uniqueChannels := make(
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
                        )

                        for _, channel := range allChannels {
                                channelID := channel.GraphChannel.ID

                                // Only process each unique channel once.
                                _, exists := uniqueChannels[channelID]
                                if exists {
                                        continue
                                }

                                uniqueChannels[channelID] = channel
                                allChannelIDs = append(allChannelIDs, channelID)

                                if channel.Policy1ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy1ID.Int64,
                                        )
                                }
                                if channel.Policy2ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy2ID.Int64,
                                        )
                                }
                        }

                        // Batch load channel data for all unique channels.
                        channelBatchData, err := batchLoadChannelData(
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
                                allPolicyIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "load channel data: %w", err)
                        }

                        // Create map of node ID to channels that involve this
                        // node.
                        nodeIDSet := make(map[int64]bool)
                        for _, nodeID := range nodeIDs {
                                nodeIDSet[nodeID] = true
                        }

                        nodeChannelMap := make(
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
                        )
                        for _, channel := range uniqueChannels {
                                // Add channel to both nodes if they're in our
                                // current page.
                                node1 := channel.GraphChannel.NodeID1
                                if nodeIDSet[node1] {
                                        nodeChannelMap[node1] = append(
                                                nodeChannelMap[node1], channel,
                                        )
                                }
                                node2 := channel.GraphChannel.NodeID2
                                if nodeIDSet[node2] {
                                        nodeChannelMap[node2] = append(
                                                nodeChannelMap[node2], channel,
                                        )
                                }
                        }

                        return &nodeCachedBatchData{
                                features:      nodeFeatures,
                                addrs:         nodeAddrs,
                                chanBatchData: channelBatchData,
                                chanMap:       nodeChannelMap,
                        }, nil
                }

                // processItem is used to process each node in the current page.
                processItem := func(ctx context.Context,
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
                        batchData *nodeCachedBatchData) error {

                        // Build feature vector for this node.
                        fv := lnwire.EmptyFeatureVector()
                        features, exists := batchData.features[nodeData.ID]
                        if exists {
                                for _, bit := range features {
                                        fv.Set(lnwire.FeatureBit(bit))
                                }
                        }

                        var nodePub route.Vertex
                        copy(nodePub[:], nodeData.PubKey)

                        nodeChannels := batchData.chanMap[nodeData.ID]

                        toNodeCallback := func() route.Vertex {
                                return nodePub
                        }

                        // Build cached channels map for this node.
                        channels := make(map[uint64]*DirectedChannel)
                        for _, channelRow := range nodeChannels {
                                directedChan, err := buildDirectedChannel(
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
                                        channelRow, batchData.chanBatchData, fv,
                                        toNodeCallback,
                                )
                                if err != nil {
                                        return err
                                }

                                channels[directedChan.ChannelID] = directedChan
                        }

                        addrs, err := buildNodeAddresses(
                                batchData.addrs[nodeData.ID],
                        )
                        if err != nil {
                                return fmt.Errorf("unable to build node "+
                                        "addresses: %w", err)
                        }

                        return cb(ctx, nodePub, addrs, channels)
                }

                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
                                return node.ID
                        },
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
                                error) {

                                return node.ID, nil
                        },
                        batchDataFunc, processItem,
                )
        }, reset)
}
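
// Usage sketch (illustrative only): collect the channel count per node using
// the cached callback form. The counts map and the reset closure shown here
// are assumptions of the sketch.
//
//        counts := make(map[route.Vertex]int)
//        err := s.ForEachNodeCached(ctx, false,
//                func(_ context.Context, node route.Vertex, _ []net.Addr,
//                        chans map[uint64]*DirectedChannel) error {
//
//                        counts[node] = len(chans)
//                        return nil
//                },
//                func() { counts = make(map[route.Vertex]int) },
//        )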

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since, as this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
        reset func()) error {

        ctx := context.TODO()

        handleChannel := func(_ context.Context,
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge := buildCacheableChannelInfo(
                        row.Scid, row.Capacity.Int64, node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                pol1, pol2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                return cb(edge, pol1, pol2)
        }

        extractCursor := func(
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

                return row.ID
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                //nolint:ll
                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
                        error) {

                        return db.ListChannelsWithPoliciesForCachePaginated(
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
                                        Version: int16(lnwire.GossipVersion1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, reset)
}

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since, as this is a directed graph, both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
        }, reset)
}
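
// Usage sketch (illustrative only): count all channel edges in the graph. The
// counter variable is an assumption of the sketch.
//
//        var numChans int
//        err := s.ForEachChannel(ctx,
//                func(info *models.ChannelEdgeInfo, p1,
//                        p2 *models.ChannelEdgePolicy) error {
//
//                        numChans++
//                        return nil
//                },
//                func() { numChans = 0 },
//        )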

// FilterChannelRange returns the channel IDs of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
        withTimestamps bool) ([]BlockChannelRange, error) {

        var (
                ctx       = context.TODO()
                startSCID = &lnwire.ShortChannelID{
                        BlockHeight: startHeight,
                }
                endSCID = lnwire.ShortChannelID{
                        BlockHeight: endHeight,
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
                        TxPosition:  math.MaxUint16,
                }
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
        )

        // 1) get all channels where channelID is between start and end chan ID.
        // 2) skip if not public (i.e., no channel_proof)
        // 3) collect that channel.
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
        //    and add those timestamps to the collected channel.
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChans, err := db.GetPublicV1ChannelsBySCID(
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channel range: %w",
                                err)
                }

                for _, dbChan := range dbChans {
                        cid := lnwire.NewShortChanIDFromInt(
                                byteOrder.Uint64(dbChan.Scid),
                        )
                        chanInfo := NewChannelUpdateInfo(
                                cid, time.Time{}, time.Time{},
                        )

                        if !withTimestamps {
                                channelsPerBlock[cid.BlockHeight] = append(
                                        channelsPerBlock[cid.BlockHeight],
                                        chanInfo,
                                )

                                continue
                        }

                        //nolint:ll
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                        Version:   int16(lnwire.GossipVersion1),
                                        ChannelID: dbChan.ID,
                                        NodeID:    dbChan.NodeID1,
                                },
                        )
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch node1 "+
                                        "policy: %w", err)
                        } else if err == nil {
                                chanInfo.Node1UpdateTimestamp = time.Unix(
                                        node1Policy.LastUpdate.Int64, 0,
                                )
                        }

                        //nolint:ll
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                        Version:   int16(lnwire.GossipVersion1),
                                        ChannelID: dbChan.ID,
                                        NodeID:    dbChan.NodeID2,
                                },
                        )
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch node2 "+
                                        "policy: %w", err)
                        } else if err == nil {
                                chanInfo.Node2UpdateTimestamp = time.Unix(
                                        node2Policy.LastUpdate.Int64, 0,
                                )
                        }

                        channelsPerBlock[cid.BlockHeight] = append(
                                channelsPerBlock[cid.BlockHeight], chanInfo,
                        )
                }

                return nil
        }, func() {
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
        }

        if len(channelsPerBlock) == 0 {
                return nil, nil
        }

        // Return the channel ranges in ascending block height order.
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
        slices.Sort(blocks)

        return fn.Map(blocks, func(block uint32) BlockChannelRange {
                return BlockChannelRange{
                        Height:   block,
                        Channels: channelsPerBlock[block],
                }
        }), nil
}
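
// Usage sketch (illustrative only): fetch the known channels for a block
// range, including update timestamps, and walk the grouped results. The
// height bounds are assumptions of the sketch.
//
//        ranges, err := s.FilterChannelRange(700_000, 700_100, true)
//        if err != nil {
//                return err
//        }
//        for _, blk := range ranges {
//                // blk.Height and blk.Channels hold the SCIDs (and
//                // timestamps) mined at that height.
//        }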

// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
        pubKey1, pubKey2 [33]byte) error {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        chanIDB := channelIDToBytes(chanID)

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.UpsertZombieChannel(
                        ctx, sqlc.UpsertZombieChannelParams{
                                Version:  int16(lnwire.GossipVersion1),
                                Scid:     chanIDB,
                                NodeKey1: pubKey1[:],
                                NodeKey2: pubKey2[:],
                        },
                )
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to upsert zombie channel "+
                        "(channel_id=%d): %w", chanID, err)
        }

        s.rejectCache.remove(chanID)
        s.chanCache.remove(chanID)

        return nil
}

// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(chanID)
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteZombieChannel(
                        ctx, sqlc.DeleteZombieChannelParams{
                                Scid:    chanIDB,
                                Version: int16(lnwire.GossipVersion1),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete zombie channel: %w",
                                err)
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrZombieEdgeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d zombie rows, "+
                                "expected 1", rows)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to mark edge live "+
                        "(channel_id=%d): %w", chanID, err)
        }

        s.rejectCache.remove(chanID)
        s.chanCache.remove(chanID)

        return err
}
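
// Usage sketch (illustrative only): the zombie index is driven by
// MarkEdgeZombie, IsZombieEdge (below) and MarkEdgeLive. A channel is marked
// as a zombie with the node keys that may resurrect it, can then be queried,
// and is cleared again once fresh updates arrive. The chanID and key values
// are assumptions of the sketch.
//
//        if err := s.MarkEdgeZombie(chanID, pub1, pub2); err != nil {
//                return err
//        }
//        isZombie, _, _, err := s.IsZombieEdge(chanID)
//        if err != nil {
//                return err
//        }
//        if isZombie {
//                // Resurrect the channel once a fresh update is seen.
//                if err := s.MarkEdgeLive(chanID); err != nil {
//                        return err
//                }
//        }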

// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
        error) {

        var (
                ctx              = context.TODO()
                isZombie         bool
                pubKey1, pubKey2 route.Vertex
                chanIDB          = channelIDToBytes(chanID)
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                zombie, err := db.GetZombieChannel(
                        ctx, sqlc.GetZombieChannelParams{
                                Scid:    chanIDB,
                                Version: int16(lnwire.GossipVersion1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }
                if err != nil {
                        return fmt.Errorf("unable to fetch zombie channel: %w",
                                err)
                }

                copy(pubKey1[:], zombie.NodeKey1)
                copy(pubKey2[:], zombie.NodeKey2)
                isZombie = true

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, route.Vertex{}, route.Vertex{},
                        fmt.Errorf("%w: %w (chanID=%d)",
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
        }

        return isZombie, pubKey1, pubKey2, nil
}

// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
        var (
                ctx        = context.TODO()
                numZombies uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                count, err := db.CountZombieChannels(
                        ctx, int16(lnwire.GossipVersion1),
                )
                if err != nil {
                        return fmt.Errorf("unable to count zombie channels: %w",
                                err)
                }

                numZombies = uint64(count)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to count zombies: %w", err)
        }

        return numZombies, nil
}

// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to
// re-add them to our database once again. If an edge does not exist within
// the database, then ErrEdgeNotFound will be returned. If strictZombiePruning
// is true, then when we mark these edges as zombies, we'll set up the keys
// such that we require the node that failed to send the fresh update to be
// the one that resurrects the channel from its zombie state. The markZombie
// bool denotes whether to mark the channel as a zombie.
//
// NOTE: part of the Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        // Keep track of which channels we end up finding so that we can
        // correctly return ErrEdgeNotFound if we do not find a channel.
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
        for _, chanID := range chanIDs {
                chanLookup[chanID] = struct{}{}
        }

        var (
                ctx   = context.TODO()
                edges []*models.ChannelEdgeInfo
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows.
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
                chanCallBack := func(ctx context.Context,
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

                        // Deleting the entry from the map indicates that we
                        // have found the channel.
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
                        delete(chanLookup, scid)

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanWithPoliciesInSCIDList(
                        ctx, db, chanCallBack, chanIDs,
                )
                if err != nil {
                        return err
                }

                if len(chanLookup) > 0 {
                        return ErrEdgeNotFound
                }

                if len(channelRows) == 0 {
                        return nil
                }

                // Batch build all channel edges.
                var chanIDsToDelete []int64
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                if markZombie {
                        for i, row := range channelRows {
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                                err := handleZombieMarking(
                                        ctx, db, row, edges[i],
                                        strictZombiePruning, scid,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to mark "+
                                                "channel as zombie: %w", err)
                                }
                        }
                }

                return s.deleteChannels(ctx, db, chanIDsToDelete)
        }, func() {
                edges = nil

                // Re-fill the lookup map.
                for _, chanID := range chanIDs {
                        chanLookup[chanID] = struct{}{}
                }
        })
        if err != nil {
                return nil, fmt.Errorf("unable to delete channel edges: %w",
                        err)
        }

        for _, chanID := range chanIDs {
                s.rejectCache.remove(chanID)
                s.chanCache.remove(chanID)
        }

        return edges, nil
}

// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        var (
                ctx              = context.TODO()
                edge             *models.ChannelEdgeInfo
                policy1, policy2 *models.ChannelEdgePolicy
                chanIDB          = channelIDToBytes(chanID)
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                row, err := db.GetChannelBySCIDWithPolicies(
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
                                Scid:    chanIDB,
                                Version: int16(lnwire.GossipVersion1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        // First check if this edge is perhaps in the zombie
                        // index.
                        zombie, err := db.GetZombieChannel(
                                ctx, sqlc.GetZombieChannelParams{
                                        Scid:    chanIDB,
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if errors.Is(err, sql.ErrNoRows) {
                                return ErrEdgeNotFound
                        } else if err != nil {
                                return fmt.Errorf("unable to check if "+
                                        "channel is zombie: %w", err)
                        }

                        // At this point, we know the channel is a zombie, so
                        // we'll return an error indicating this, and we will
                        // populate the edge info with the public keys of each
                        // party as this is the only information we have about
                        // it.
                        node1, err := route.NewVertexFromBytes(zombie.NodeKey1)
                        if err != nil {
                                return err
                        }
                        node2, err := route.NewVertexFromBytes(zombie.NodeKey2)
                        if err != nil {
                                return err
                        }
                        zombieEdge, err := models.NewV1Channel(
                                0, chainhash.Hash{}, node1, node2,
                                &models.ChannelV1Fields{},
                        )
                        if err != nil {
                                return err
                        }
                        edge = zombieEdge

                        return ErrZombieEdge
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                node1, node2, err := buildNodeVertices(
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
                )
                if err != nil {
                        return err
                }

                edge, err = getAndBuildEdgeInfo(
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                policy1, policy2, err = getAndBuildChanPolicies(
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
                        node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                // If we are returning the ErrZombieEdge, then we also need to
                // return the edge info as the method comment indicates that
                // this will be populated when the edge is a zombie.
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
                        err)
        }

        return edge, policy1, policy2, nil
}
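
// Usage sketch (illustrative only): callers that want to treat zombie
// channels specially can unwrap the returned error. The chanID value is an
// assumption of the sketch.
//
//        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
//        switch {
//        case errors.Is(err, ErrZombieEdge):
//                // info still carries the node public keys of the zombie
//                // channel, while pol1 and pol2 are nil.
//        case errors.Is(err, ErrEdgeNotFound):
//                // The channel is unknown to the graph.
//        case err != nil:
//                return err
//        }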

// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        var (
                ctx              = context.TODO()
                edge             *models.ChannelEdgeInfo
                policy1, policy2 *models.ChannelEdgePolicy
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                row, err := db.GetChannelByOutpointWithPolicies(
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
                                Outpoint: op.String(),
                                Version:  int16(lnwire.GossipVersion1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrEdgeNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err = getAndBuildEdgeInfo(
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                policy1, policy2, err = getAndBuildChanPolicies(
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
                        node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
                        err)
        }

        return edge, policy1, policy2, nil
}

// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
        bool, error) {

        ctx := context.TODO()

        var (
                exists          bool
                isZombie        bool
                node1LastUpdate time.Time
                node2LastUpdate time.Time
        )

        // We'll query the cache with the shared lock held to allow multiple
        // readers to access values in the cache concurrently if they exist.
        s.cacheMu.RLock()
        if entry, ok := s.rejectCache.get(chanID); ok {
                s.cacheMu.RUnlock()
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
                exists, isZombie = entry.flags.unpack()

                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
        }
        s.cacheMu.RUnlock()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        // The item was not found with the shared lock, so we'll acquire the
        // exclusive lock and check the cache again in case another method added
        // the entry to the cache while no lock was held.
        if entry, ok := s.rejectCache.get(chanID); ok {
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
                exists, isZombie = entry.flags.unpack()

                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
        }

        chanIDB := channelIDToBytes(chanID)
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                channel, err := db.GetChannelBySCID(
                        ctx, sqlc.GetChannelBySCIDParams{
                                Scid:    chanIDB,
                                Version: int16(lnwire.GossipVersion1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        // Check if it is a zombie channel.
                        isZombie, err = db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    chanIDB,
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("could not check if channel "+
                                        "is zombie: %w", err)
                        }

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                exists = true

                policy1, err := db.GetChannelPolicyByChannelAndNode(
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                Version:   int16(lnwire.GossipVersion1),
                                ChannelID: channel.ID,
                                NodeID:    channel.NodeID1,
                        },
                )
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
                        return fmt.Errorf("unable to fetch channel policy: %w",
                                err)
                } else if err == nil {
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
                }

                policy2, err := db.GetChannelPolicyByChannelAndNode(
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                Version:   int16(lnwire.GossipVersion1),
                                ChannelID: channel.ID,
                                NodeID:    channel.NodeID2,
                        },
                )
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
                        return fmt.Errorf("unable to fetch channel policy: %w",
                                err)
                } else if err == nil {
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, time.Time{}, false, false,
                        fmt.Errorf("unable to fetch channel: %w", err)
        }

        s.rejectCache.insert(chanID, rejectCacheEntry{
                upd1Time: node1LastUpdate.Unix(),
                upd2Time: node2LastUpdate.Unix(),
                flags:    packRejectFlags(exists, isZombie),
        })

        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
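
// Usage sketch (illustrative only): the four non-error return values are the
// last update time seen for each direction, whether the edge exists, and
// whether it is a known zombie. The chanID value is an assumption.
//
//        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
//        if err != nil {
//                return err
//        }
//        switch {
//        case isZombie:
//                // Ignore gossip for this channel.
//        case exists:
//                // upd1 and upd2 tell us how fresh each policy is.
//        }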

// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
        var (
                ctx       = context.TODO()
                channelID uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.GetSCIDByOutpoint(
                        ctx, sqlc.GetSCIDByOutpointParams{
                                Outpoint: chanPoint.String(),
                                Version:  int16(lnwire.GossipVersion1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrEdgeNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel ID: %w",
                                err)
                }

                channelID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
        }

        return channelID, nil
}
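
// NOTE(editor): illustrative sketch, not part of the original file. The store
// keys channels by the raw 8-byte SCID returned by GetSCIDByOutpoint, and
// byteOrder is assumed here to be big endian (the encoding used for lnwire
// short channel IDs). The uint64 form used by callers then round-trips
// through that key roughly as follows, with channelIDToBytes assumed to do
// the equivalent of PutUint64:
//
//    var key [8]byte
//    binary.BigEndian.PutUint64(key[:], scid) // DB key for the channel.
//    chanID := binary.BigEndian.Uint64(key[:]) // back to the compact ID.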

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
        ctx := context.TODO()

        var isPublic bool
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to check if node is "+
                        "public: %w", err)
        }

        return isPublic, nil
}

// FetchChanInfos returns the set of channel edges that correspond to the
// passed channel IDs. If an edge in the query is unknown to the database, it
// will be skipped and the result will contain only those edges that exist at
// the time of the query. This can be used to respond to peer queries that are
// seeking to fill in gaps in their view of the channel graph.
//
// NOTE: part of the Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
        var (
                ctx   = context.TODO()
                edges = make(map[uint64]ChannelEdge)
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows.
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
                chanCallBack := func(ctx context.Context,
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

                        channelRows = append(channelRows, row)
                        return nil
                }

                err := s.forEachChanWithPoliciesInSCIDList(
                        ctx, db, chanCallBack, chanIDs,
                )
                if err != nil {
                        return err
                }

                if len(channelRows) == 0 {
                        return nil
                }

                // Batch build all channel edges.
                chans, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel edges: %w",
                                err)
                }

                for _, c := range chans {
                        edges[c.Info.ChannelID] = c
                }

                return err
        }, func() {
                clear(edges)
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        res := make([]ChannelEdge, 0, len(edges))
        for _, chanID := range chanIDs {
                edge, ok := edges[chanID]
                if !ok {
                        continue
                }

                res = append(res, edge)
        }

        return res, nil
}

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
        db SQLQueries, cb func(ctx context.Context,
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
        chanIDs []uint64) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
                error) {

                return db.GetChannelsBySCIDWithPolicies(
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
                cb,
        )
}
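
// NOTE(editor): a rough sketch of the batching pattern used by the SCID
// helpers above, under the assumption that sqldb.ExecuteBatchQuery converts
// each input item with the supplied converter, splits the resulting keys into
// pages bounded by the query config, runs the query wrapper once per page and
// feeds every returned row to the callback. pageSize stands in for whatever
// limit the query config imposes:
//
//    keys := make([][]byte, 0, len(chanIDs))
//    for _, id := range chanIDs {
//            keys = append(keys, channelIDToBytes(id))
//    }
//    for start := 0; start < len(keys); start += pageSize {
//            end := min(start+pageSize, len(keys))
//            rows, err := queryWrapper(ctx, keys[start:end])
//            if err != nil {
//                    return err
//            }
//            for _, row := range rows {
//                    if err := cb(ctx, row); err != nil {
//                            return err
//                    }
//            }
//    }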

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of
// channel IDs that we don't know and that are not known zombies of the passed
// set. In other words, we perform a set difference of our set of channel IDs
// and the ones passed in. This method can be used by callers to determine the
// set of channels another peer knows of that we don't. The ChannelUpdateInfos
// for the known zombies are also returned.
//
// NOTE: part of the Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {

        var (
                ctx          = context.TODO()
                newChanIDs   []uint64
                knownZombies []ChannelUpdateInfo
                infoLookup   = make(
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
                )
        )

        // We first build a lookup map of the channel IDs to the
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
        // already know about.
        for _, chanInfo := range chansInfo {
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
        }

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // The call-back function deletes known channels from
                // infoLookup, so that we can later check which channels are
                // zombies by only looking at the remaining channels in the set.
                cb := func(ctx context.Context,
                        channel sqlc.GraphChannel) error {

                        delete(infoLookup, byteOrder.Uint64(channel.Scid))

                        return nil
                }

                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
                if err != nil {
                        return fmt.Errorf("unable to iterate through "+
                                "channels: %w", err)
                }

                // We want to ensure that we deal with the channels in the
                // same order that they were passed in, so we iterate over the
                // original chansInfo slice and then check if that channel is
                // still in the infoLookup map.
                for _, chanInfo := range chansInfo {
                        channelID := chanInfo.ShortChannelID.ToUint64()
                        if _, ok := infoLookup[channelID]; !ok {
                                continue
                        }

                        isZombie, err := db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    channelIDToBytes(channelID),
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch zombie "+
                                        "channel: %w", err)
                        }

                        if isZombie {
                                knownZombies = append(knownZombies, chanInfo)

                                continue
                        }

                        newChanIDs = append(newChanIDs, channelID)
                }

                return nil
        }, func() {
                newChanIDs = nil
                knownZombies = nil
                // Rebuild the infoLookup map in case of a rollback.
                for _, chanInfo := range chansInfo {
                        scid := chanInfo.ShortChannelID.ToUint64()
                        infoLookup[scid] = chanInfo
                }
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        return newChanIDs, knownZombies, nil
}
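
// NOTE(editor): hypothetical caller sketch, not part of the original file.
// Given SCIDs gossiped by a peer, only the ShortChannelID field of each
// ChannelUpdateInfo is relied on below; peerSCIDs and store are placeholder
// names:
//
//    infos := make([]ChannelUpdateInfo, 0, len(peerSCIDs))
//    for _, scid := range peerSCIDs {
//            infos = append(infos, ChannelUpdateInfo{
//                    ShortChannelID: lnwire.NewShortChanIDFromInt(scid),
//            })
//    }
//    newIDs, zombies, err := store.FilterKnownChanIDs(infos)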

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
        chansInfo []ChannelUpdateInfo) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GraphChannel, error) {

                return db.GetChannelsBySCIDs(
                        ctx, sqlc.GetChannelsBySCIDsParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
                channelID := chanInfo.ShortChannelID.ToUint64()

                return channelIDToBytes(channelID)
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
                cb,
        )
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
        var ctx = context.TODO()

        var prunedNodes []route.Vertex
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var err error
                prunedNodes, err = s.pruneGraphNodes(ctx, db)

                return err
        }, func() {
                prunedNodes = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
        }

        return prunedNodes, nil
}

// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any channel whose funding
// output is spent by a transaction in the block will be deleted from the
// graph. Additionally, the "prune tip", or the last block which has been used
// to prune the graph, is stored so callers can ensure the graph is fully in
// sync with the current UTXO state. A slice of channels that have been closed
// by the target block along with any pruned nodes are returned if the
// function succeeds without error.
//
// NOTE: part of the Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows that need to be pruned.
                var channelRows []sqlc.GetChannelsByOutpointsRow
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                if len(channelRows) == 0 {
                        // There are no channels to prune. So we can exit early
                        // after updating the prune log.
                        err = db.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHash:   blockHash[:],
                                        BlockHeight: int64(blockHeight),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry: %w", err)
                        }

                        return nil
                }

                // Batch build all channel edges for pruning.
                var chansToDelete []int64
                closedChans, chansToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
                if err != nil {
                        return fmt.Errorf("unable to prune graph nodes: %w",
                                err)
                }

                return nil
        }, func() {
                prunedNodes = nil
                closedChans = nil
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
        }

        for _, channel := range closedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return closedChans, prunedNodes, nil
}
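
// NOTE(editor): hypothetical usage sketch, not part of the original file. A
// caller typically invokes PruneGraph once per connected block with the set
// of outpoints spent in that block; spentOutpoints, blockHash, height and
// store are placeholder names:
//
//    closed, pruned, err := store.PruneGraph(spentOutpoints, &blockHash, height)
//    if err != nil {
//            return err
//    }
//    log.Debugf("pruned %d channels and %d nodes", len(closed), len(pruned))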

// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
                row sqlc.GetChannelsByOutpointsRow) error) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context,
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
                error) {

                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
        }

        // Define the conversion function from Outpoint to string.
        outpointToString := func(outpoint *wire.OutPoint) string {
                return outpoint.String()
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
                queryWrapper, cb,
        )
}

// deleteChannels deletes the channels with the given database IDs, issuing
// the deletions in batches.
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
        dbIDs []int64) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
                return nil, db.DeleteChannels(ctx, ids)
        }

        idConverter := func(id int64) int64 {
                return id
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
                queryWrapper, func(ctx context.Context, _ any) error {
                        return nil
                },
        )
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
        var (
                ctx        = context.TODO()
                edgePoints []EdgePoint
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                handleChannel := func(_ context.Context,
                        channel sqlc.ListChannelsPaginatedRow) error {

                        // TODO(elle): update to handle V2 channels.
                        pkScript, err := genMultiSigP2WSH(
                                channel.BitcoinKey1, channel.BitcoinKey2,
                        )
                        if err != nil {
                                return err
                        }

                        op, err := wire.NewOutPointFromString(channel.Outpoint)
                        if err != nil {
                                return err
                        }

                        edgePoints = append(edgePoints, EdgePoint{
                                FundingPkScript: pkScript,
                                OutPoint:        *op,
                        })

                        return nil
                }

                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

                        return db.ListChannelsPaginated(
                                ctx, sqlc.ListChannelsPaginatedParams{
                                        Version: int16(lnwire.GossipVersion1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
                        return row.ID
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, func() {
                edgePoints = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        return edgePoints, nil
}

// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
        var (
                ctx       = context.TODO()
                tipHash   chainhash.Hash
                tipHeight uint32
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                pruneTip, err := db.GetPruneTip(ctx)
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrGraphNeverPruned
                } else if err != nil {
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
                }

                tipHash = chainhash.Hash(pruneTip.BlockHash)
                tipHeight = uint32(pruneTip.BlockHeight)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, 0, err
        }

        return &tipHash, tipHeight, nil
}

// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
        db SQLQueries) ([]route.Vertex, error) {

        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
        if err != nil {
                return nil, fmt.Errorf("unable to delete unconnected "+
                        "nodes: %w", err)
        }

        prunedNodes := make([]route.Vertex, len(nodeKeys))
        for i, nodeKey := range nodeKeys {
                pub, err := route.NewVertexFromBytes(nodeKey)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse pubkey "+
                                "from bytes: %w", err)
                }

                prunedNodes[i] = pub
        }

        return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph as a result of the
// disconnected block are returned.
//
// NOTE: part of the Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
        []*models.ChannelEdgeInfo, error) {

        ctx := context.TODO()

        var (
                // Every channel having a ShortChannelID starting at 'height'
                // will no longer be confirmed.
                startShortChanID = lnwire.ShortChannelID{
                        BlockHeight: height,
                }

                // Delete everything after this height from the db up until the
                // SCID alias range.
                endShortChanID = aliasmgr.StartingAlias

                removedChans []*models.ChannelEdgeInfo

                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsBySCIDRange(
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels: %w", err)
                }

                if len(rows) == 0 {
                        // No channels to disconnect, but still clean up prune
                        // log.
                        return db.DeletePruneLogEntriesInRange(
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                        StartHeight: int64(height),
                                        EndHeight: int64(
                                                endShortChanID.BlockHeight,
                                        ),
                                },
                        )
                }

                // Batch build all channel edges for disconnection.
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
                        ctx, s.cfg, db, rows,
                )
                if err != nil {
                        return err
                }

                removedChans = channelEdges

                err = s.deleteChannels(ctx, db, chanIDsToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                return db.DeletePruneLogEntriesInRange(
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                StartHeight: int64(height),
                                EndHeight:   int64(endShortChanID.BlockHeight),
                        },
                )
        }, func() {
                removedChans = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to disconnect block at "+
                        "height: %w", err)
        }

        s.cacheMu.Lock()
        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }
        s.cacheMu.Unlock()

        return removedChans, nil
}
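
// NOTE(editor): hypothetical usage sketch, not part of the original file.
// Because the deletion range above starts at the passed height and runs up to
// the SCID alias range, a single call at the first orphaned height is enough
// to remove every channel confirmed at or above it during a reorg; forkHeight
// and store are placeholder names:
//
//    removed, err := store.DisconnectBlockAtHeight(forkHeight + 1)
//    if err != nil {
//            return err
//    }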

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
        proof *models.ChannelAuthProof) error {

        // For now, we only support v1 channel proofs.
        if proof.Version != lnwire.GossipVersion1 {
                return fmt.Errorf("only v1 channel proofs supported, got v%d",
                        proof.Version)
        }

        var (
                ctx       = context.TODO()
                scidBytes = channelIDToBytes(scid.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.AddV1ChannelProof(
                        ctx, sqlc.AddV1ChannelProofParams{
                                Scid:              scidBytes,
                                Node1Signature:    proof.NodeSig1(),
                                Node2Signature:    proof.NodeSig2(),
                                Bitcoin1Signature: proof.BitcoinSig1(),
                                Bitcoin2Signature: proof.BitcoinSig2(),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to add edge proof: %w", err)
                }

                n, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if n == 0 {
                        return fmt.Errorf("no rows affected when adding edge "+
                                "proof for SCID %v", scid)
                } else if n > 1 {
                        return fmt.Errorf("multiple rows affected when adding "+
                                "edge proof for SCID %v: %d rows affected",
                                scid, n)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to add edge proof: %w", err)
        }

        return nil
}

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(scid.ToUint64())
        )

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.InsertClosedChannel(ctx, chanIDB)
        }, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
        var (
                ctx      = context.TODO()
                isClosed bool
                chanIDB  = channelIDToBytes(scid.ToUint64())
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("unable to fetch closed channel: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to fetch closed channel: %w",
                        err)
        }

        return isClosed, nil
}
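
// NOTE(editor): hypothetical usage sketch showing how the two closed-SCID
// helpers pair up when gossip for a known-closed channel arrives; store is a
// placeholder name:
//
//    if err := store.PutClosedScid(scid); err != nil {
//            return err
//    }
//    closed, err := store.IsClosedScid(scid) // closed == true on success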

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
        reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
        }, reset)
}

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read-only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
        db    SQLQueries
        chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
        chain chainhash.Hash) *sqlNodeTraverser {

        return &sqlNodeTraverser{
                db:    db,
                chain: chain,
        }
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, _ func()) error {

        ctx := context.TODO()

        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, lnwire.GossipVersion1, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(lnwire.GossipVersion1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge := buildCacheableChannelInfo(
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
                        node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}
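
// NOTE(editor): illustrative sketch of the policy-direction selection used
// above, not part of the original file. For the queried node on a channel
// between node1 and node2, policy1 is the one set by node1 (governing traffic
// towards node2) and vice versa, so the outgoing/incoming pair is simply
// swapped when the queried node is node2. pickPolicies is a hypothetical
// helper that assumes both policies are known:
//
//    func pickPolicies(isNode1 bool, p1, p2 *models.CachedEdgePolicy) (
//            out, in *models.CachedEdgePolicy) {
//
//            if isNode1 {
//                    return p1, p2
//            }
//            return p2, p1
//    }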

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
                features *lnwire.FeatureVector) error) error {

        handleNode := func(_ context.Context,
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
                featureBits map[int64][]int) error {

                fv := lnwire.EmptyFeatureVector()
                if features, exists := featureBits[dbNode.ID]; exists {
                        for _, bit := range features {
                                fv.Set(lnwire.FeatureBit(bit))
                        }
                }

                var pub route.Vertex
                copy(pub[:], dbNode.PubKey)

                return processNode(dbNode.ID, pub, fv)
        }

        queryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                return db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(lnwire.GossipVersion1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
                return row.ID
        }

        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (map[int64][]int, error) {

                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
                batchQueryFunc, handleNode,
        )
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Collect all the channel and policy IDs.
        var (
                chanIDs   = make([]int64, 0, len(rows))
                policyIDs = make([]int64, 0, 2*len(rows))
        )
        for _, row := range rows {
                chanIDs = append(chanIDs, row.GraphChannel.ID)

                if row.Policy1ID.Valid {
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
                }
                if row.Policy2ID.Valid {
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
                }
        }

        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
        )
        if err != nil {
                return fmt.Errorf("unable to batch load channel data: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.GraphChannel.NodeID2
                p2ToNode := row.GraphChannel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
        error) {

        var (
                node1Pub, node2Pub route.Vertex
                isNode1            bool
                chanIDB            = channelIDToBytes(edge.ChannelID)
        )

        // Check that this edge policy refers to a channel that we already
        // know of. We do this explicitly so that we can return the appropriate
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
        // abort the transaction which would abort the entire batch.
        dbChan, err := tx.GetChannelAndNodesBySCID(
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return node1Pub, node2Pub, false, ErrEdgeNotFound
        } else if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "fetch channel(%v): %w", edge.ChannelID, err)
        }

        copy(node1Pub[:], dbChan.Node1PubKey)
        copy(node2Pub[:], dbChan.Node2PubKey)

        // Figure out which node this edge is from.
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.NodeID1
        if !isNode1 {
                nodeID = dbChan.NodeID2
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
                Version:     int16(lnwire.GossipVersion1),
                ChannelID:   dbChan.ID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return node1Pub, node2Pub, isNode1,
                        fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "marshal extra opaque data: %w", err)
        }

        // Update the channel policy's extra signed fields.
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
                        "policy extra TLVs: %w", err)
        }

        return node1Pub, node2Pub, isNode1, nil
}
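
// NOTE(editor): the direction check above follows BOLT 7: bit 0 of the
// channel_flags field in a channel_update indicates which node produced the
// update. A minimal sketch of the same test:
//
//    // 0 => update from node1, 1 => update from node2.
//    isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0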

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, *models.Node,
        error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(v),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, cfg, db, dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
        node2Pub route.Vertex) *models.CachedEdgeInfo {

        return &models.CachedEdgeInfo{
                ChannelID:     byteOrder.Uint64(scid),
                NodeKey1Bytes: node1Pub,
                NodeKey2Bytes: node2Pub,
                Capacity:      btcutil.Amount(capacity),
        }
}

// buildNode constructs a Node instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        dbNode sqlc.GraphNode) (*models.Node, error) {

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        return buildNodeWithBatchData(dbNode, data)
}

// isKnownGossipVersion checks whether the provided gossip version is known
// and supported.
func isKnownGossipVersion(v lnwire.GossipVersion) bool {
        switch v {
        case lnwire.GossipVersion1:
                return true
        case lnwire.GossipVersion2:
                return true
        default:
                return false
        }
}

// buildNodeWithBatchData builds a models.Node instance from the provided
// sqlc.GraphNode and batchNodeData. If the node does have features, addresses
// or extra fields, then the corresponding entries are expected to be present
// in the batchNodeData.
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
        batchData *batchNodeData) (*models.Node, error) {

        v := lnwire.GossipVersion(dbNode.Version)

        if !isKnownGossipVersion(v) {
                return nil, fmt.Errorf("unknown node version: %d", v)
        }

        pub, err := route.NewVertexFromBytes(dbNode.PubKey)
        if err != nil {
                return nil, fmt.Errorf("unable to parse pubkey: %w", err)
        }

        node := models.NewShellNode(v, pub)

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.AuthSigBytes = dbNode.Signature

        if dbNode.Alias.Valid {
                node.Alias = fn.Some(dbNode.Alias.String)
        }
        if dbNode.LastUpdate.Valid {
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
        }
        if dbNode.BlockHeight.Valid {
                node.LastBlockHeight = uint32(dbNode.BlockHeight.Int64)
        }

        if dbNode.Color.Valid {
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode color: %w",
                                err)
                }

                node.Color = fn.Some(nodeColor)
        }

        // Use preloaded features.
        if features, exists := batchData.features[dbNode.ID]; exists {
                fv := lnwire.EmptyFeatureVector()
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
                node.Features = fv
        }

        // Use preloaded addresses.
        addresses, exists := batchData.addresses[dbNode.ID]
        if exists && len(addresses) > 0 {
                node.Addresses, err = buildNodeAddresses(addresses)
                if err != nil {
                        return nil, fmt.Errorf("unable to build addresses "+
                                "for node(%d): %w", dbNode.ID, err)
                }
        }

        // Use preloaded extra fields.
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
                if v == lnwire.GossipVersion1 {
                        records := lnwire.CustomRecords(extraFields)
                        recs, err := records.Serialize()
                        if err != nil {
                                return nil, fmt.Errorf("unable to serialize "+
                                        "extra signed fields: %w", err)
                        }

                        if len(recs) != 0 {
                                node.ExtraOpaqueData = recs
                        }
                } else if len(extraFields) > 0 {
                        node.ExtraSignedFields = extraFields
                }
        }

        return node, nil
}
3640

3641
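// NOTE (editor's illustrative sketch, not part of the original file): the
// builder above short-circuits on an empty signature, which is how "shell"
// nodes (known only from a channel announcement) are represented. A minimal
// sketch of that behaviour, assuming a sqlc.GraphNode row and preloaded
// batchNodeData are already in hand:
//
//	node, err := buildNodeWithBatchData(dbNode, batchData)
//	if err != nil {
//		return err
//	}
//	if !node.HaveAnnouncement() {
//		// Only the gossip version and public key are populated; the
//		// alias, features, addresses, etc. stay empty until a node
//		// announcement for this key is processed.
//	}
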
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodes []sqlc.GraphNode,
	cb func(dbID int64, node *models.Node) error) error {

	// Extract node IDs for batch loading.
	nodeIDs := make([]int64, len(nodes))
	for i, node := range nodes {
		nodeIDs[i] = node.ID
	}

	// Batch load all related data for this page.
	batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
	if err != nil {
		return fmt.Errorf("unable to batch load node data: %w", err)
	}

	for _, dbNode := range nodes {
		node, err := buildNodeWithBatchData(dbNode, batchData)
		if err != nil {
			return fmt.Errorf("unable to build node(id=%d): %w",
				dbNode.ID, err)
		}

		if err := cb(dbNode.ID, node); err != nil {
			return fmt.Errorf("callback failed for node(id=%d): %w",
				dbNode.ID, err)
		}
	}

	return nil
}

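// Example (editor's illustrative sketch, not part of the original file): a
// caller that already holds a page of sqlc.GraphNode rows (for instance from
// ListNodesPaginated) can collect the fully built nodes via the callback.
// Here ctx, cfg (a *sqldb.QueryConfig), db and page are assumed to be in
// scope.
//
//	var built []*models.Node
//	err := forEachNodeInBatch(
//		ctx, cfg, db, page,
//		func(dbID int64, node *models.Node) error {
//			built = append(built, node)
//			return nil
//		},
//	)
//	if err != nil {
//		return err
//	}
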
// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
	nodeID int64) (*lnwire.FeatureVector, error) {

	rows, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil {
		return nil, fmt.Errorf("unable to get node(%d) features: %w",
			nodeID, err)
	}

	features := lnwire.EmptyFeatureVector()
	for _, feature := range rows {
		features.Set(lnwire.FeatureBit(feature.FeatureBit))
	}

	return features, nil
}

// upsertNodeAncillaryData updates the node's features, addresses, and extra
// signed fields. This is common logic shared by upsertNode and
// upsertSourceNode.
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
	nodeID int64, node *models.Node) error {

	// Update the node's features.
	err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
	if err != nil {
		return fmt.Errorf("inserting node features: %w", err)
	}

	// Update the node's addresses.
	err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
	if err != nil {
		return fmt.Errorf("inserting node addresses: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra := node.ExtraSignedFields
	if node.Version == lnwire.GossipVersion1 {
		extra, err = marshalExtraOpaqueData(node.ExtraOpaqueData)
		if err != nil {
			return fmt.Errorf("unable to marshal extra opaque "+
				"data: %w", err)
		}
	}

	// Update the node's extra signed fields.
	err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
	if err != nil {
		return fmt.Errorf("inserting node extra TLVs: %w", err)
	}

	return nil
}

// populateNodeParams populates the common node parameters from a models.Node.
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
func populateNodeParams(node *models.Node,
	setParams func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
		colorStr sql.NullString, signature []byte)) error {

	if !node.HaveAnnouncement() {
		return nil
	}

	var (
		alias, colorStr             sql.NullString
		lastUpdate, lastBlockHeight sql.NullInt64
	)
	node.Color.WhenSome(func(rgba color.RGBA) {
		colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
	})
	node.Alias.WhenSome(func(s string) {
		alias = sqldb.SQLStrValid(s)
	})

	switch node.Version {
	case lnwire.GossipVersion1:
		lastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())

	case lnwire.GossipVersion2:
		lastBlockHeight = sqldb.SQLInt64(int64(node.LastBlockHeight))

	default:
		return fmt.Errorf("unknown gossip version: %d", node.Version)
	}

	setParams(
		lastUpdate, lastBlockHeight, alias, colorStr, node.AuthSigBytes,
	)

	return nil
}

// buildNodeUpsertParams builds the parameters for upserting a node using the
// strict UpsertNode query (requires timestamp to be increasing).
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
	params := sqlc.UpsertNodeParams{
		Version: int16(node.Version),
		PubKey:  node.PubKeyBytes[:],
	}

	err := populateNodeParams(
		node, func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
			colorStr sql.NullString,
			signature []byte) {

			params.LastUpdate = lastUpdate
			params.BlockHeight = lastBlockHeight
			params.Alias = alias
			params.Color = colorStr
			params.Signature = signature
		},
	)

	return params, err
}

// buildSourceNodeUpsertParams builds the parameters for upserting the source
// node using the lenient UpsertSourceNode query (allows same timestamp).
func buildSourceNodeUpsertParams(node *models.Node) (
	sqlc.UpsertSourceNodeParams, error) {

	params := sqlc.UpsertSourceNodeParams{
		Version: int16(node.Version),
		PubKey:  node.PubKeyBytes[:],
	}

	err := populateNodeParams(
		node, func(lastUpdate, lastBlock sql.NullInt64, alias,
			colorStr sql.NullString, signature []byte) {

			params.BlockHeight = lastBlock
			params.LastUpdate = lastUpdate
			params.Alias = alias
			params.Color = colorStr
			params.Signature = signature
		},
	)

	return params, err
}

// upsertSourceNode upserts the source node record into the database using a
// less strict upsert that allows updates even when the timestamp hasn't
// changed. This is necessary to handle concurrent updates to our own node
// during startup and runtime. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertSourceNode(ctx context.Context, db SQLQueries,
	node *models.Node) (int64, error) {

	params, err := buildSourceNodeUpsertParams(node)
	if err != nil {
		return 0, err
	}

	nodeID, err := db.UpsertSourceNode(ctx, params)
	if err != nil {
		return 0, fmt.Errorf("upserting source node(%x): %w",
			node.PubKeyBytes, err)
	}

	// We can exit here if we don't have the announcement yet.
	if !node.HaveAnnouncement() {
		return nodeID, nil
	}

	// Update the ancillary node data (features, addresses, extra fields).
	err = upsertNodeAncillaryData(ctx, db, nodeID, node)
	if err != nil {
		return 0, err
	}

	return nodeID, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
	node *models.Node) (int64, error) {

	if !isKnownGossipVersion(node.Version) {
		return 0, fmt.Errorf("unknown gossip version: %d", node.Version)
	}

	params, err := buildNodeUpsertParams(node)
	if err != nil {
		return 0, err
	}

	nodeID, err := db.UpsertNode(ctx, params)
	if err != nil {
		return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
			err)
	}

	// We can exit here if we don't have the announcement yet.
	if !node.HaveAnnouncement() {
		return nodeID, nil
	}

	// Update the ancillary node data (features, addresses, extra fields).
	err = upsertNodeAncillaryData(ctx, db, nodeID, node)
	if err != nil {
		return 0, err
	}

	return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
	features *lnwire.FeatureVector) error {

	// Get any existing features for the node.
	existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	// Copy the node's latest set of feature bits.
	newFeatures := make(map[int32]struct{})
	if features != nil {
		for feature := range features.Features() {
			newFeatures[int32(feature)] = struct{}{}
		}
	}

	// For any current feature that already exists in the DB, remove it
	// from the in-memory map. For any existing feature that does not exist
	// in the in-memory map, delete it from the database.
	for _, feature := range existingFeatures {
		// The feature is still present, so there are no updates to be
		// made.
		if _, ok := newFeatures[feature.FeatureBit]; ok {
			delete(newFeatures, feature.FeatureBit)
			continue
		}

		// The feature is no longer present, so we remove it from the
		// database.
		err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature.FeatureBit,
		})
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) "+
				"feature(%v): %w", nodeID, feature.FeatureBit,
				err)
		}
	}

	// Any remaining entries in newFeatures are new features that need to
	// be added to the database for the first time.
	for feature := range newFeatures {
		err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature,
		})
		if err != nil {
			return fmt.Errorf("unable to insert node(%d) "+
				"feature(%v): %w", nodeID, feature, err)
		}
	}

	return nil
}

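// The upsert above follows a diff-and-reconcile pattern: load the stored set,
// delete what disappeared, insert what is new. A minimal standalone sketch of
// the same idea (editor's illustration, not part of the original file), using
// plain in-memory sets in place of the node_features table:
//
//	existing := map[int32]struct{}{0: {}, 12: {}}
//	desired := map[int32]struct{}{12: {}, 23: {}}
//
//	for bit := range existing {
//		if _, ok := desired[bit]; ok {
//			delete(desired, bit) // Already stored, nothing to do.
//			continue
//		}
//		// bit 0: stored but no longer advertised -> delete the row.
//	}
//	// desired now only holds bit 23, the single new row to insert.
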
// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
	v lnwire.GossipVersion, nodePub route.Vertex) (*lnwire.FeatureVector,
	error) {

	rows, err := queries.GetNodeFeaturesByPubKey(
		ctx, sqlc.GetNodeFeaturesByPubKeyParams{
			PubKey:  nodePub[:],
			Version: int16(v),
		},
	)
	if err != nil {
		return nil, fmt.Errorf("unable to get node(%s) features: %w",
			nodePub, err)
	}

	features := lnwire.EmptyFeatureVector()
	for _, bit := range rows {
		features.Set(lnwire.FeatureBit(bit))
	}

	return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
	addressTypeIPv4   dbAddressType = 1
	addressTypeIPv6   dbAddressType = 2
	addressTypeTorV2  dbAddressType = 3
	addressTypeTorV3  dbAddressType = 4
	addressTypeDNS    dbAddressType = 5
	addressTypeOpaque dbAddressType = math.MaxInt8
)

// collectAddressRecords collects the addresses from the provided
// net.Addr slice and returns a map of dbAddressType to a slice of address
// strings.
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
	error) {

	// Copy the node's latest set of addresses.
	newAddresses := map[dbAddressType][]string{
		addressTypeIPv4:   {},
		addressTypeIPv6:   {},
		addressTypeTorV2:  {},
		addressTypeTorV3:  {},
		addressTypeDNS:    {},
		addressTypeOpaque: {},
	}
	addAddr := func(t dbAddressType, addr net.Addr) {
		newAddresses[t] = append(newAddresses[t], addr.String())
	}

	for _, address := range addresses {
		switch addr := address.(type) {
		case *net.TCPAddr:
			if ip4 := addr.IP.To4(); ip4 != nil {
				addAddr(addressTypeIPv4, addr)
			} else if ip6 := addr.IP.To16(); ip6 != nil {
				addAddr(addressTypeIPv6, addr)
			} else {
				return nil, fmt.Errorf("unhandled IP "+
					"address: %v", addr)
			}

		case *tor.OnionAddr:
			switch len(addr.OnionService) {
			case tor.V2Len:
				addAddr(addressTypeTorV2, addr)
			case tor.V3Len:
				addAddr(addressTypeTorV3, addr)
			default:
				return nil, fmt.Errorf("invalid length for " +
					"a tor address")
			}

		case *lnwire.DNSAddress:
			addAddr(addressTypeDNS, addr)

		case *lnwire.OpaqueAddrs:
			addAddr(addressTypeOpaque, addr)

		default:
			return nil, fmt.Errorf("unhandled address type: %T",
				addr)
		}
	}

	return newAddresses, nil
}

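// Example (editor's illustrative sketch, not part of the original file):
// grouping a node's advertised addresses by storage type. IPv4 and IPv6
// addresses land in separate buckets, and the per-type ordering of the input
// is preserved, which matters when the announcement is later reconstructed.
//
//	addrs := []net.Addr{
//		&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 9735},
//		&net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 9735},
//	}
//	grouped, err := collectAddressRecords(addrs)
//	if err != nil {
//		return err
//	}
//	// grouped[addressTypeIPv4] == []string{"203.0.113.1:9735"}
//	// grouped[addressTypeIPv6] == []string{"[2001:db8::1]:9735"}
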
// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
	addresses []net.Addr) error {

	// Delete any existing addresses for the node. This is required since
	// even if the new set of addresses is the same, the ordering may have
	// changed for a given address type.
	err := db.DeleteNodeAddresses(ctx, nodeID)
	if err != nil {
		return fmt.Errorf("unable to delete node(%d) addresses: %w",
			nodeID, err)
	}

	newAddresses, err := collectAddressRecords(addresses)
	if err != nil {
		return err
	}

	// Any remaining entries in newAddresses are new addresses that need to
	// be added to the database for the first time.
	for addrType, addrList := range newAddresses {
		for position, addr := range addrList {
			err := db.UpsertNodeAddress(
				ctx, sqlc.UpsertNodeAddressParams{
					NodeID:   nodeID,
					Type:     int16(addrType),
					Address:  addr,
					Position: int32(position),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert "+
					"node(%d) address(%v): %w", nodeID,
					addr, err)
			}
		}
	}

	return nil
}

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
	error) {

	// GetNodeAddresses ensures that the addresses for a given type are
	// returned in the same order as they were inserted.
	rows, err := db.GetNodeAddresses(ctx, id)
	if err != nil {
		return nil, err
	}

	addresses := make([]net.Addr, 0, len(rows))
	for _, row := range rows {
		address := row.Address

		addr, err := parseAddress(dbAddressType(row.Type), address)
		if err != nil {
			return nil, fmt.Errorf("unable to parse address "+
				"for node(%d): %v: %w", id, address, err)
		}

		addresses = append(addresses, addr)
	}

	// If we have no addresses, then we'll return nil instead of an
	// empty slice.
	if len(addresses) == 0 {
		addresses = nil
	}

	return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new
// types, and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
	nodeID int64, extraFields map[uint64][]byte) error {

	// Get any existing extra signed fields for the node.
	existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
	if err != nil {
		return err
	}

	// Make a lookup map of the existing field types so that we can use it
	// to keep track of any fields we should delete.
	m := make(map[uint64]bool)
	for _, field := range existingFields {
		m[uint64(field.Type)] = true
	}

	// For all the new fields, we'll upsert them and remove them from the
	// map of existing fields.
	for tlvType, value := range extraFields {
		err = db.UpsertNodeExtraType(
			ctx, sqlc.UpsertNodeExtraTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
				Value:  value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to upsert node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}

		// Remove the field from the map of existing fields if it was
		// present.
		delete(m, tlvType)
	}

	// For all the fields that are left in the map of existing fields,
	// we'll delete them as they are no longer present in the new set of
	// fields.
	for tlvType := range m {
		err = db.DeleteExtraNodeType(
			ctx, sqlc.DeleteExtraNodeTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}
	}

	return nil
}

// srcNodeInfo holds the information about the source node of the graph.
type srcNodeInfo struct {
	// id is the DB level ID of the source node entry in the "nodes" table.
	id int64

	// pub is the public key of the source node.
	pub route.Vertex
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
	version lnwire.GossipVersion) (int64, route.Vertex, error) {

	s.srcNodeMu.Lock()
	defer s.srcNodeMu.Unlock()

	// If we already have the source node ID and pub key cached, then
	// return them.
	if info, ok := s.srcNodes[version]; ok {
		return info.id, info.pub, nil
	}

	var pubKey route.Vertex

	nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
	if err != nil {
		return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
			err)
	}

	if len(nodes) == 0 {
		return 0, pubKey, ErrSourceNodeNotSet
	} else if len(nodes) > 1 {
		return 0, pubKey, fmt.Errorf("multiple source nodes for "+
			"protocol %s found", version)
	}

	copy(pubKey[:], nodes[0].PubKey)

	s.srcNodes[version] = &srcNodeInfo{
		id:  nodes[0].NodeID,
		pub: pubKey,
	}

	return nodes[0].NodeID, pubKey, nil
}

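// Callers of getSourceNode only pay the SQL round trip once per gossip
// version: the first lookup populates s.srcNodes under srcNodeMu, and every
// later call for the same version is served from that cache. A hedged usage
// sketch (editor's illustration, not part of the original file):
//
//	id, pub, err := s.getSourceNode(ctx, db, lnwire.GossipVersion1)
//	if errors.Is(err, ErrSourceNodeNotSet) {
//		// No source node has been persisted for this version yet.
//	}
//	_ = id
//	_ = pub
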
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
	r := bytes.NewReader(data)

	tlvStream, err := tlv.NewStream()
	if err != nil {
		return nil, err
	}

	// Since ExtraOpaqueData is provided by a potentially malicious peer,
	// pass it into the P2P decoding variant.
	parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
	}
	if len(parsedTypes) == 0 {
		return nil, nil
	}

	records := make(map[uint64][]byte)
	for k, v := range parsedTypes {
		records[uint64(k)] = v
	}

	return records, nil
}

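// marshalExtraOpaqueData is the inverse of the serialisation used when
// reading nodes and channels back out: lnwire.CustomRecords flattens a
// type->value map into a canonical TLV byte stream, and the function above
// recovers the map. A round-trip sketch (editor's illustration, not part of
// the original file):
//
//	records := lnwire.CustomRecords{65537: []byte{0x01}}
//	blob, err := records.Serialize()
//	if err != nil {
//		return err
//	}
//	parsed, err := marshalExtraOpaqueData(blob)
//	if err != nil {
//		return err
//	}
//	// parsed == map[uint64][]byte{65537: {0x01}}
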
// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
	edge *models.ChannelEdgeInfo) error {

	v := lnwire.GossipVersion1

	// For now, we only support V1 channel edges in the SQL store.
	if edge.Version != v {
		return fmt.Errorf("only V1 channel edges supported, got V%d",
			edge.Version)
	}

	// Make sure that at least a "shell" entry for each node is present in
	// the nodes table.
	node1DBID, err := maybeCreateShellNode(
		ctx, db, v, edge.NodeKey1Bytes,
	)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	node2DBID, err := maybeCreateShellNode(
		ctx, db, v, edge.NodeKey2Bytes,
	)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	var capacity sql.NullInt64
	if edge.Capacity != 0 {
		capacity = sqldb.SQLInt64(int64(edge.Capacity))
	}

	createParams := sqlc.CreateChannelParams{
		Version:  int16(v),
		Scid:     channelIDToBytes(edge.ChannelID),
		NodeID1:  node1DBID,
		NodeID2:  node2DBID,
		Outpoint: edge.ChannelPoint.String(),
		Capacity: capacity,
	}
	edge.BitcoinKey1Bytes.WhenSome(func(vertex route.Vertex) {
		createParams.BitcoinKey1 = vertex[:]
	})
	edge.BitcoinKey2Bytes.WhenSome(func(vertex route.Vertex) {
		createParams.BitcoinKey2 = vertex[:]
	})

	if edge.AuthProof != nil {
		proof := edge.AuthProof

		createParams.Node1Signature = proof.NodeSig1()
		createParams.Node2Signature = proof.NodeSig2()
		createParams.Bitcoin1Signature = proof.BitcoinSig1()
		createParams.Bitcoin2Signature = proof.BitcoinSig2()
	}

	// Insert the new channel record.
	dbChanID, err := db.CreateChannel(ctx, createParams)
	if err != nil {
		return err
	}

	// Insert any channel features.
	for feature := range edge.Features.Features() {
		err = db.InsertChannelFeature(
			ctx, sqlc.InsertChannelFeatureParams{
				ChannelID:  dbChanID,
				FeatureBit: int32(feature),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to insert channel(%d) "+
				"feature(%v): %w", dbChanID, feature, err)
		}
	}

	// Finally, insert any extra TLV fields in the channel announcement.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	for tlvType, value := range extra {
		err := db.UpsertChannelExtraType(
			ctx, sqlc.UpsertChannelExtraTypeParams{
				ChannelID: dbChanID,
				Type:      int64(tlvType),
				Value:     value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to upsert channel(%d) "+
				"extra signed field(%v): %w", edge.ChannelID,
				tlvType, err)
		}
	}

	return nil
}

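// Example (editor's illustrative sketch, not part of the original file):
// constructing a minimal V1 edge and handing it to insertChannel. The helper
// models.NewV1Channel mirrors the call used by buildEdgeInfoWithBatchData
// further down; scid, chainHash, node1, node2, btcKey1, btcKey2 and chanPoint
// are assumed to be in scope.
//
//	edge, err := models.NewV1Channel(
//		scid, chainHash, node1, node2,
//		&models.ChannelV1Fields{
//			BitcoinKey1Bytes: btcKey1,
//			BitcoinKey2Bytes: btcKey2,
//		},
//		models.WithChannelPoint(chanPoint),
//		models.WithCapacity(btcutil.Amount(1_000_000)),
//	)
//	if err != nil {
//		return err
//	}
//	if err := insertChannel(ctx, db, edge); err != nil {
//		return err
//	}
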
// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
	v lnwire.GossipVersion, pubKey route.Vertex) (int64, error) {

	dbNode, err := db.GetNodeByPubKey(
		ctx, sqlc.GetNodeByPubKeyParams{
			PubKey:  pubKey[:],
			Version: int16(v),
		},
	)
	// The node exists. Return the ID.
	if err == nil {
		return dbNode.ID, nil
	} else if !errors.Is(err, sql.ErrNoRows) {
		return 0, err
	}

	// Otherwise, the node does not exist, so we create a shell entry for
	// it.
	id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
		Version: int16(v),
		PubKey:  pubKey[:],
	})
	if err != nil {
		return 0, fmt.Errorf("unable to create shell node: %w", err)
	}

	return id, nil
}

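// maybeCreateShellNode is a classic get-or-create: a read first, and only on
// sql.ErrNoRows an upsert of the two shell columns (version, pubkey). A
// hedged usage sketch (editor's illustration, not part of the original file),
// with remotePub assumed to be a route.Vertex already in scope:
//
//	nodeDBID, err := maybeCreateShellNode(
//		ctx, db, lnwire.GossipVersion1, remotePub,
//	)
//	if err != nil {
//		return err
//	}
//	// nodeDBID can now be used as a foreign key (e.g. for a channel row)
//	// even before the node's own announcement has been seen.
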
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields
// in the database. This includes deleting any existing types and then
// inserting the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
	chanPolicyID int64, extraFields map[uint64][]byte) error {

	// Delete all existing extra signed fields for the channel policy.
	err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
	if err != nil {
		return fmt.Errorf("unable to delete "+
			"existing policy extra signed fields for policy %d: %w",
			chanPolicyID, err)
	}

	// Insert all new extra signed fields for the channel policy.
	for tlvType, value := range extraFields {
		err = db.UpsertChanPolicyExtraType(
			ctx, sqlc.UpsertChanPolicyExtraTypeParams{
				ChannelPolicyID: chanPolicyID,
				Type:            int64(tlvType),
				Value:           value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to insert "+
				"channel_policy(%d) extra signed field(%v): %w",
				chanPolicyID, tlvType, err)
		}
	}

	return nil
}

// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided dbChan row and also fetches any other required information
// to construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
	db SQLQueries, dbChan sqlc.GraphChannel, node1,
	node2 route.Vertex) (*models.ChannelEdgeInfo, error) {

	data, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load channel data: %w",
			err)
	}

	return buildEdgeInfoWithBatchData(
		cfg.ChainHash, dbChan, node1, node2, data,
	)
}

// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
	dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
	batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {

	if dbChan.Version != int16(lnwire.GossipVersion1) {
		return nil, fmt.Errorf("unsupported channel version: %d",
			dbChan.Version)
	}

	// Use pre-loaded features and extras types.
	fv := lnwire.EmptyFeatureVector()
	if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
		for _, bit := range features {
			fv.Set(lnwire.FeatureBit(bit))
		}
	}

	var extras map[uint64][]byte
	channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
	if exists {
		extras = channelExtras
	} else {
		extras = make(map[uint64][]byte)
	}

	op, err := wire.NewOutPointFromString(dbChan.Outpoint)
	if err != nil {
		return nil, err
	}

	recs, err := lnwire.CustomRecords(extras).Serialize()
	if err != nil {
		return nil, fmt.Errorf("unable to serialize extra signed "+
			"fields: %w", err)
	}
	if recs == nil {
		recs = make([]byte, 0)
	}

	btcKey1, err := route.NewVertexFromBytes(dbChan.BitcoinKey1)
	if err != nil {
		return nil, err
	}
	btcKey2, err := route.NewVertexFromBytes(dbChan.BitcoinKey2)
	if err != nil {
		return nil, err
	}

	channel, err := models.NewV1Channel(
		byteOrder.Uint64(dbChan.Scid), chain, node1, node2,
		&models.ChannelV1Fields{
			BitcoinKey1Bytes: btcKey1,
			BitcoinKey2Bytes: btcKey2,
			ExtraOpaqueData:  recs,
		},
		models.WithChannelPoint(*op),
		models.WithCapacity(btcutil.Amount(dbChan.Capacity.Int64)),
		models.WithFeatures(fv.RawFeatureVector),
	)
	if err != nil {
		return nil, err
	}

	// We always set all the signatures at the same time, so we can
	// safely check if one signature is present to determine if we have the
	// rest of the signatures for the auth proof.
	if len(dbChan.Bitcoin1Signature) > 0 {
		// For v1 channels, we have four signatures.
		if dbChan.Version == int16(lnwire.GossipVersion1) {
			channel.AuthProof = models.NewV1ChannelAuthProof(
				dbChan.Node1Signature,
				dbChan.Node2Signature,
				dbChan.Bitcoin1Signature,
				dbChan.Bitcoin2Signature,
			)
		}
		// TODO(elle): Add v2 support when needed.
	}

	return channel, nil
}

// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
	route.Vertex, error) {

	node1Vertex, err := route.NewVertexFromBytes(node1Pub)
	if err != nil {
		return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
			"create vertex from node1 pubkey: %w", err)
	}

	node2Vertex, err := route.NewVertexFromBytes(node2Pub)
	if err != nil {
		return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
			"create vertex from node2 pubkey: %w", err)
	}

	return node1Vertex, node2Vertex, nil
}

// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
// retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
// the provided sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
	channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	if dbPol1 == nil && dbPol2 == nil {
		return nil, nil, nil
	}

	var policyIDs = make([]int64, 0, 2)
	if dbPol1 != nil {
		policyIDs = append(policyIDs, dbPol1.ID)
	}
	if dbPol2 != nil {
		policyIDs = append(policyIDs, dbPol2.ID)
	}

	batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to batch load channel "+
			"data: %w", err)
	}

	pol1, err := buildChanPolicyWithBatchData(
		dbPol1, channelID, node2, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
	}

	pol2, err := buildChanPolicyWithBatchData(
		dbPol2, channelID, node1, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
	}

	return pol1, pol2, nil
}

// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
	channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
	*models.CachedEdgePolicy, error) {

	var p1, p2 *models.CachedEdgePolicy
	if dbPol1 != nil {
		policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
		if err != nil {
			return nil, nil, err
		}

		p1 = models.NewCachedPolicy(policy1)
	}
	if dbPol2 != nil {
		policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
		if err != nil {
			return nil, nil, err
		}

		p2 = models.NewCachedPolicy(policy2)
	}

	return p1, p2, nil
}

// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
	extras map[uint64][]byte,
	toNode route.Vertex) (*models.ChannelEdgePolicy, error) {

	recs, err := lnwire.CustomRecords(extras).Serialize()
	if err != nil {
		return nil, fmt.Errorf("unable to serialize extra signed "+
			"fields: %w", err)
	}

	var inboundFee fn.Option[lnwire.Fee]
	if dbPolicy.InboundFeeRateMilliMsat.Valid ||
		dbPolicy.InboundBaseFeeMsat.Valid {

		inboundFee = fn.Some(lnwire.Fee{
			BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
			FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
		})
	}

	return &models.ChannelEdgePolicy{
		SigBytes:  dbPolicy.Signature,
		ChannelID: channelID,
		LastUpdate: time.Unix(
			dbPolicy.LastUpdate.Int64, 0,
		),
		MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
			dbPolicy.MessageFlags,
		),
		ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
			dbPolicy.ChannelFlags,
		),
		TimeLockDelta: uint16(dbPolicy.Timelock),
		MinHTLC: lnwire.MilliSatoshi(
			dbPolicy.MinHtlcMsat,
		),
		MaxHTLC: lnwire.MilliSatoshi(
			dbPolicy.MaxHtlcMsat.Int64,
		),
		FeeBaseMSat: lnwire.MilliSatoshi(
			dbPolicy.BaseFeeMsat,
		),
		FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
		ToNode:                    toNode,
		InboundFee:                inboundFee,
		ExtraOpaqueData:           recs,
	}, nil
}

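// The inbound fee is only attached when at least one of the two nullable
// columns is set; otherwise the fn.Option stays None and the policy behaves
// as if no inbound fee was ever advertised. A small sketch of that mapping
// (editor's illustration, not part of the original file):
//
//	// Both columns NULL            -> fn.None[lnwire.Fee]().
//	// inbound_base_fee_msat = 1000,
//	// inbound_fee_rate NULL       -> fn.Some(lnwire.Fee{BaseFee: 1000}),
//	// since an unset sql.NullInt64 reads back as zero.
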
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
	*sqlc.GraphChannelPolicy, error) {

	var policy1, policy2 *sqlc.GraphChannelPolicy
	switch r := row.(type) {
	case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
		if r.Policy1Timelock.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
			}
		}
		if r.Policy2Timelock.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelByOutpointWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
×
4860
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4861
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4862
                                Disabled:                r.Policy1Disabled,
×
4863
                                MessageFlags:            r.Policy1MessageFlags,
×
4864
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4865
                                Signature:               r.Policy1Signature,
×
4866
                        }
×
4867
                }
×
4868
                if r.Policy2ID.Valid {
×
4869
                        policy2 = &sqlc.GraphChannelPolicy{
×
4870
                                ID:                      r.Policy2ID.Int64,
×
4871
                                Version:                 r.Policy2Version.Int16,
×
4872
                                ChannelID:               r.GraphChannel.ID,
×
4873
                                NodeID:                  r.Policy2NodeID.Int64,
×
4874
                                Timelock:                r.Policy2Timelock.Int32,
×
4875
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4876
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4877
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4878
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4879
                                LastUpdate:              r.Policy2LastUpdate,
×
4880
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4881
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4882
                                Disabled:                r.Policy2Disabled,
×
4883
                                MessageFlags:            r.Policy2MessageFlags,
×
4884
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4885
                                Signature:               r.Policy2Signature,
×
4886
                        }
×
4887
                }
×
4888

4889
                return policy1, policy2, nil
×
4890

4891
        case sqlc.ListChannelsForNodeIDsRow:
×
4892
                if r.Policy1ID.Valid {
×
4893
                        policy1 = &sqlc.GraphChannelPolicy{
×
4894
                                ID:                      r.Policy1ID.Int64,
×
4895
                                Version:                 r.Policy1Version.Int16,
×
4896
                                ChannelID:               r.GraphChannel.ID,
×
4897
                                NodeID:                  r.Policy1NodeID.Int64,
×
4898
                                Timelock:                r.Policy1Timelock.Int32,
×
4899
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4900
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4901
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4902
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4903
                                LastUpdate:              r.Policy1LastUpdate,
×
4904
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4905
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4906
                                Disabled:                r.Policy1Disabled,
×
4907
                                MessageFlags:            r.Policy1MessageFlags,
×
4908
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4909
                                Signature:               r.Policy1Signature,
×
4910
                        }
×
4911
                }
×
4912
                if r.Policy2ID.Valid {
×
4913
                        policy2 = &sqlc.GraphChannelPolicy{
×
4914
                                ID:                      r.Policy2ID.Int64,
×
4915
                                Version:                 r.Policy2Version.Int16,
×
4916
                                ChannelID:               r.GraphChannel.ID,
×
4917
                                NodeID:                  r.Policy2NodeID.Int64,
×
4918
                                Timelock:                r.Policy2Timelock.Int32,
×
4919
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4920
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4921
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4922
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4923
                                LastUpdate:              r.Policy2LastUpdate,
×
4924
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4925
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4926
                                Disabled:                r.Policy2Disabled,
×
4927
                                MessageFlags:            r.Policy2MessageFlags,
×
4928
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4929
                                Signature:               r.Policy2Signature,
×
4930
                        }
×
4931
                }
×
4932

4933
                return policy1, policy2, nil
×
4934

4935
        case sqlc.ListChannelsByNodeIDRow:
×
4936
                if r.Policy1ID.Valid {
×
4937
                        policy1 = &sqlc.GraphChannelPolicy{
×
4938
                                ID:                      r.Policy1ID.Int64,
×
4939
                                Version:                 r.Policy1Version.Int16,
×
4940
                                ChannelID:               r.GraphChannel.ID,
×
4941
                                NodeID:                  r.Policy1NodeID.Int64,
×
4942
                                Timelock:                r.Policy1Timelock.Int32,
×
4943
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4944
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4945
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4946
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4947
                                LastUpdate:              r.Policy1LastUpdate,
×
4948
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4949
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4950
                                Disabled:                r.Policy1Disabled,
×
4951
                                MessageFlags:            r.Policy1MessageFlags,
×
4952
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4953
                                Signature:               r.Policy1Signature,
×
4954
                        }
×
4955
                }
×
4956
                if r.Policy2ID.Valid {
×
4957
                        policy2 = &sqlc.GraphChannelPolicy{
×
4958
                                ID:                      r.Policy2ID.Int64,
×
4959
                                Version:                 r.Policy2Version.Int16,
×
4960
                                ChannelID:               r.GraphChannel.ID,
×
4961
                                NodeID:                  r.Policy2NodeID.Int64,
×
4962
                                Timelock:                r.Policy2Timelock.Int32,
×
4963
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4964
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4965
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4966
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4967
                                LastUpdate:              r.Policy2LastUpdate,
×
4968
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4969
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4970
                                Disabled:                r.Policy2Disabled,
×
4971
                                MessageFlags:            r.Policy2MessageFlags,
×
4972
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4973
                                Signature:               r.Policy2Signature,
×
4974
                        }
×
4975
                }
×
4976

4977
                return policy1, policy2, nil
×
4978

4979
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4980
                if r.Policy1ID.Valid {
×
4981
                        policy1 = &sqlc.GraphChannelPolicy{
×
4982
                                ID:                      r.Policy1ID.Int64,
×
4983
                                Version:                 r.Policy1Version.Int16,
×
4984
                                ChannelID:               r.GraphChannel.ID,
×
4985
                                NodeID:                  r.Policy1NodeID.Int64,
×
4986
                                Timelock:                r.Policy1Timelock.Int32,
×
4987
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4988
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4989
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4990
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4991
                                LastUpdate:              r.Policy1LastUpdate,
×
4992
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4993
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4994
                                Disabled:                r.Policy1Disabled,
×
4995
                                MessageFlags:            r.Policy1MessageFlags,
×
4996
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4997
                                Signature:               r.Policy1Signature,
×
4998
                        }
×
4999
                }
×
5000
                if r.Policy2ID.Valid {
×
5001
                        policy2 = &sqlc.GraphChannelPolicy{
×
5002
                                ID:                      r.Policy2ID.Int64,
×
5003
                                Version:                 r.Policy2Version.Int16,
×
5004
                                ChannelID:               r.GraphChannel.ID,
×
5005
                                NodeID:                  r.Policy2NodeID.Int64,
×
5006
                                Timelock:                r.Policy2Timelock.Int32,
×
5007
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5008
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5009
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5010
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5011
                                LastUpdate:              r.Policy2LastUpdate,
×
5012
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5013
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5014
                                Disabled:                r.Policy2Disabled,
×
5015
                                MessageFlags:            r.Policy2MessageFlags,
×
5016
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5017
                                Signature:               r.Policy2Signature,
×
5018
                        }
×
5019
                }
×
5020

5021
                return policy1, policy2, nil
×
5022

5023
        case sqlc.GetChannelsByIDsRow:
×
5024
                if r.Policy1ID.Valid {
×
5025
                        policy1 = &sqlc.GraphChannelPolicy{
×
5026
                                ID:                      r.Policy1ID.Int64,
×
5027
                                Version:                 r.Policy1Version.Int16,
×
5028
                                ChannelID:               r.GraphChannel.ID,
×
5029
                                NodeID:                  r.Policy1NodeID.Int64,
×
5030
                                Timelock:                r.Policy1Timelock.Int32,
×
5031
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5032
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5033
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5034
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5035
                                LastUpdate:              r.Policy1LastUpdate,
×
5036
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5037
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5038
                                Disabled:                r.Policy1Disabled,
×
5039
                                MessageFlags:            r.Policy1MessageFlags,
×
5040
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5041
                                Signature:               r.Policy1Signature,
×
5042
                        }
×
5043
                }
×
5044
                if r.Policy2ID.Valid {
×
5045
                        policy2 = &sqlc.GraphChannelPolicy{
×
5046
                                ID:                      r.Policy2ID.Int64,
×
5047
                                Version:                 r.Policy2Version.Int16,
×
5048
                                ChannelID:               r.GraphChannel.ID,
×
5049
                                NodeID:                  r.Policy2NodeID.Int64,
×
5050
                                Timelock:                r.Policy2Timelock.Int32,
×
5051
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5052
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5053
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5054
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5055
                                LastUpdate:              r.Policy2LastUpdate,
×
5056
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5057
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5058
                                Disabled:                r.Policy2Disabled,
×
5059
                                MessageFlags:            r.Policy2MessageFlags,
×
5060
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5061
                                Signature:               r.Policy2Signature,
×
5062
                        }
×
5063
                }
×
5064

5065
                return policy1, policy2, nil
×
5066

5067
        default:
×
5068
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
5069
                        "extractChannelPolicies: %T", r)
×
5070
        }
5071
}

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], channelID)

        return chanIDB[:]
}
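
// For instance, assuming byteOrder is the package's binary.ByteOrder (as
// used above), an SCID round-trips through this helper as follows (sketch):
//
//      scidBytes := channelIDToBytes(123456789)
//      scid := byteOrder.Uint64(scidBytes) // scid == 123456789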

// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
        if len(addresses) == 0 {
                return nil, nil
        }

        result := make([]net.Addr, 0, len(addresses))
        for _, addr := range addresses {
                netAddr, err := parseAddress(addr.addrType, addr.address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address %s "+
                                "of type %d: %w", addr.address, addr.addrType,
                                err)
                }
                if netAddr != nil {
                        result = append(result, netAddr)
                }
        }

        // If we have no valid addresses, return nil instead of empty slice.
        if len(result) == 0 {
                return nil, nil
        }

        return result, nil
}
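
// As a rough illustration with hypothetical values, a single stored IPv4
// address yields a one-element slice backed by a *net.TCPAddr:
//
//      addrs, _ := buildNodeAddresses([]nodeAddress{{
//              addrType: addressTypeIPv4,
//              address:  "203.0.113.1:9735",
//      }})
//      // addrs[0].(*net.TCPAddr) now describes 203.0.113.1:9735.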

// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// DNS, and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
        switch addrType {
        case addressTypeIPv4:
                tcp, err := net.ResolveTCPAddr("tcp4", address)
                if err != nil {
                        return nil, err
                }

                tcp.IP = tcp.IP.To4()

                return tcp, nil

        case addressTypeIPv6:
                tcp, err := net.ResolveTCPAddr("tcp6", address)
                if err != nil {
                        return nil, err
                }

                return tcp, nil

        case addressTypeTorV3, addressTypeTorV2:
                service, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split tor "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &tor.OnionAddr{
                        OnionService: service,
                        Port:         port,
                }, nil

        case addressTypeDNS:
                hostname, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split DNS "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &lnwire.DNSAddress{
                        Hostname: hostname,
                        Port:     uint16(port),
                }, nil

        case addressTypeOpaque:
                opaque, err := hex.DecodeString(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode opaque "+
                                "address: %v", address)
                }

                return &lnwire.OpaqueAddrs{
                        Payload: opaque,
                }, nil

        default:
                return nil, fmt.Errorf("unknown address type: %v", addrType)
        }
}
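
// For example, a Tor v3 address stored in host:port form is split into its
// onion service and port (sketch; the onion name here is a shortened,
// hypothetical placeholder):
//
//      addr, _ := parseAddress(addressTypeTorV3, "exampleonion.onion:9735")
//      onion := addr.(*tor.OnionAddr)
//      // onion.OnionService == "exampleonion.onion", onion.Port == 9735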

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
        // features is a map from a DB node ID to the feature bits for that
        // node.
        features map[int64][]int

        // addresses is a map from a DB node ID to the node's addresses.
        addresses map[int64][]nodeAddress

        // extraFields is a map from a DB node ID to the extra signed fields
        // for that node.
        extraFields map[int64]map[uint64][]byte
}
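
// As a purely hypothetical illustration, after loading a batch that contains
// DB node ID 42, the struct might hold:
//
//      data.features[42]    // e.g. []int{0, 5, 7}
//      data.addresses[42]   // e.g. one nodeAddress per stored address
//      data.extraFields[42] // e.g. map[uint64][]byte{1: {0x01}}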

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
        addrType dbAddressType
        position int32
        address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance containing
// the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

        // Batch load the node features.
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "features: %w", err)
        }

        // Batch load the node addresses.
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "addresses: %w", err)
        }

        // Batch load the node extra signed fields.
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node extra "+
                        "signed fields: %w", err)
        }

        return &batchNodeData{
                features:    features,
                addresses:   addrs,
                extraFields: extraTypes,
        }, nil
}

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
// wrapper around the GetNodeAddressesBatch query. It returns a map from
// node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
// query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the
// provided batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
// query. It returns a map from DB channel ID to a map of TLV type to extra
// signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol lnwire.GossipVersion,
        processNode func(context.Context, int64,
                *models.Node) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
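
// A minimal usage sketch, assuming ctx, cfg and db are already in scope:
//
//      err := forEachNodePaginated(
//              ctx, cfg, db, lnwire.GossipVersion1,
//              func(_ context.Context, dbID int64, _ *models.Node) error {
//                      // Inspect or index the node keyed by dbID here.
//                      return nil
//              },
//      )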

// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(lnwire.GossipVersion1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(lnwire.GossipVersion1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}