
lightningnetwork / lnd / 19360227540

14 Nov 2025 09:28AM UTC coverage: 65.202% (+0.001%) from 65.201%
Merge 53fb06b42 into ff20dd281
Pull Request #10371: graph/db: fix SetSourceNode no rows error

0 of 113 new or added lines in 2 files covered. (0.0%)

53 existing lines in 14 files now uncovered.

137603 of 211041 relevant lines covered (65.2%)

20746.52 hits per line

Source File
/graph/db/sql_store.go

package graphdb

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	color "image/color"
	"iter"
	"maps"
	"math"
	"net"
	"slices"
	"strconv"
	"sync"
	"time"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/aliasmgr"
	"github.com/lightningnetwork/lnd/batch"
	"github.com/lightningnetwork/lnd/fn/v2"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/sqldb"
	"github.com/lightningnetwork/lnd/sqldb/sqlc"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/lightningnetwork/lnd/tor"
)

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
	GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

	/*
		Migration specific queries.

		NOTE: these should not be used in code other than migrations.
		Once sqldbv2 is in place, these can be removed from this struct
		as then migrations will have their own dedicated queries
		structs.
	*/
	InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
	InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
	InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryConfig holds configuration values for SQL queries.
	QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
	options ...StoreOptionModifier) (*SQLStore, error) {

	opts := DefaultOptions()
	for _, o := range options {
		o(opts)
	}

	if opts.NoMigration {
		return nil, fmt.Errorf("the NoMigration option is not yet " +
			"supported for SQL stores")
	}

	s := &SQLStore{
		cfg:         cfg,
		db:          db,
		rejectCache: newRejectCache(opts.RejectCacheSize),
		chanCache:   newChannelCache(opts.ChannelCacheSize),
		srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
	}

	s.chanScheduler = batch.NewTimeScheduler(
		db, &s.cacheMu, opts.BatchCommitInterval,
	)
	s.nodeScheduler = batch.NewTimeScheduler(
		db, nil, opts.BatchCommitInterval,
	)

	return s, nil
}
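
// Editor's note: the commented snippet below is an illustrative sketch and is
// not part of the original file. It shows how a caller might construct the
// store; the chainHash, queryCfg and db values are placeholders for a real
// chain hash, *sqldb.QueryConfig and BatchedSQLQueries implementation.
//
//	store, err := NewSQLStore(
//		&SQLStoreConfig{
//			ChainHash: chainHash,
//			QueryCfg:  queryCfg,
//		},
//		db,
//	)
//	if err != nil {
//		// Handle construction error.
//	}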

// AddNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddNode(ctx context.Context,
	node *models.Node, opts ...batch.SchedulerOption) error {

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Do: func(queries SQLQueries) error {
			_, err := upsertNode(ctx, queries, node)

			// It is possible that two of the same node
			// announcements are both being processed in the same
			// batch. This may cause the UpsertNode conflict to
			// be hit since we require at the db layer that the
			// new last_update is greater than the existing
			// last_update. We need to gracefully handle this here.
			if errors.Is(err, sql.ErrNoRows) {
				return nil
			}

			return err
		},
	}

	return s.nodeScheduler.Execute(ctx, r)
}
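
// Editor's note: an illustrative sketch, not part of the original file. A
// caller simply hands the parsed announcement to AddNode and the write is
// then batched by the node scheduler. The "store", "ctx" and "node" values
// are placeholders.
//
//	if err := store.AddNode(ctx, node); err != nil {
//		// The node announcement could not be persisted.
//	}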

// FetchNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchNode(ctx context.Context,
	pubKey route.Vertex) (*models.Node, error) {

	var node *models.Node
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		_, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	return node, nil
}

// HasNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		exists     bool
		lastUpdate time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  pubKey[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		exists = true

		if dbNode.LastUpdate.Valid {
			lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		addresses []net.Addr
		known     bool
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, check if the node exists and get its DB ID if it
		// does.
		dbID, err := db.GetNodeIDByPubKey(
			ctx, sqlc.GetNodeIDByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  nodePub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}

		known = true

		addresses, err = getNodeAddresses(ctx, db, dbID)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}

// DeleteNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteNode(ctx context.Context,
	pubKey route.Vertex) error {

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteNodeByPubKey(
			ctx, sqlc.DeleteNodeByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  pubKey[:],
			},
		)
		if err != nil {
			return err
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrGraphNodeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d rows, expected 1", rows)
		}

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to delete node: %w", err)
	}

	return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
	*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	var (
		ctx     = context.TODO()
		chanIDs []uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  pub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrNodeAliasNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		if !dbNode.Alias.Valid {
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
	error) {

	var node *models.Node
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		_, nodePub, err := s.getSourceNode(
			ctx, db, lnwire.GossipVersion1,
		)
		if err != nil {
			return fmt.Errorf("unable to fetch V1 source node: %w",
				err)
		}

		_, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
	node *models.Node) error {

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// For the source node, we use a less strict upsert that allows
		// updates even when the timestamp hasn't changed. This handles
		// the race condition where multiple goroutines (e.g.,
		// setSelfNode, createNewHiddenService, RPC updates) read the
		// same old timestamp, independently increment it, and try to
		// write concurrently. We want all parameter changes to persist,
		// even if timestamps collide.
		id, err := upsertSourceNode(ctx, db, node)
		if err != nil {
			return fmt.Errorf("unable to upsert source node: %w",
				err)
		}

		// Make sure that if a source node for this version is already
		// set, then the ID is the same as the one we are about to set.
		dbSourceNodeID, _, err := s.getSourceNode(
			ctx, db, lnwire.GossipVersion1,
		)
		if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		} else if err == nil {
			if dbSourceNodeID != id {
				return fmt.Errorf("v1 source node already "+
					"set to a different node: %d vs %d",
					dbSourceNodeID, id)
			}

			return nil
		}

		return db.AddSourceNode(ctx, id)
	}, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
	opts ...IteratorOption) iter.Seq2[*models.Node, error] {

	cfg := defaultIteratorConfig()
	for _, opt := range opts {
		opt(cfg)
	}

	return func(yield func(*models.Node, error) bool) {
		var (
			ctx            = context.TODO()
			lastUpdateTime sql.NullInt64
			lastPubKey     = make([]byte, 33)
			hasMore        = true
		)

		// Each iteration, we'll read a batch amount of nodes, yield
		// them, then decide if we have more or not.
		for hasMore {
			var batch []*models.Node

			//nolint:ll
			err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
				//nolint:ll
				params := sqlc.GetNodesByLastUpdateRangeParams{
					StartTime: sqldb.SQLInt64(
						startTime.Unix(),
					),
					EndTime: sqldb.SQLInt64(
						endTime.Unix(),
					),
					LastUpdate: lastUpdateTime,
					LastPubKey: lastPubKey,
					OnlyPublic: sql.NullBool{
						Bool:  cfg.iterPublicNodes,
						Valid: true,
					},
					MaxResults: sqldb.SQLInt32(
						cfg.nodeUpdateIterBatchSize,
					),
				}
				rows, err := db.GetNodesByLastUpdateRange(
					ctx, params,
				)
				if err != nil {
					return err
				}

				hasMore = len(rows) == cfg.nodeUpdateIterBatchSize

				err = forEachNodeInBatch(
					ctx, s.cfg.QueryCfg, db, rows,
					func(_ int64, node *models.Node) error {
						batch = append(batch, node)

						// Update pagination cursors
						// based on the last processed
						// node.
						lastUpdateTime = sql.NullInt64{
							Int64: node.LastUpdate.
								Unix(),
							Valid: true,
						}
						lastPubKey = node.PubKeyBytes[:]

						return nil
					},
				)
				if err != nil {
					return fmt.Errorf("unable to build "+
						"nodes: %w", err)
				}

				return nil
			}, func() {
				batch = []*models.Node{}
			})

			if err != nil {
				log.Errorf("NodeUpdatesInHorizon batch "+
					"error: %v", err)

				yield(&models.Node{}, err)

				return
			}

			for _, node := range batch {
				if !yield(node, nil) {
					return
				}
			}

			// If the batch didn't yield anything, then we're done.
			if len(batch) == 0 {
				break
			}
		}
	}
}
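
// Editor's note: an illustrative sketch, not part of the original file,
// showing how the returned iter.Seq2 iterator is consumed with Go's
// range-over-func syntax. The "store", "start" and "end" values are
// placeholders.
//
//	for node, err := range store.NodeUpdatesInHorizon(start, end) {
//		if err != nil {
//			// A batch failed to load; stop iterating.
//			break
//		}
//
//		// Process the node announcement.
//	}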

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	var alreadyExists bool
	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			alreadyExists = false
		},
		Do: func(tx SQLQueries) error {
			chanIDB := channelIDToBytes(edge.ChannelID)

			// Make sure that the channel doesn't already exist. We
			// do this explicitly instead of relying on catching a
			// unique constraint error because relying on SQL to
			// throw that error would abort the entire batch of
			// transactions.
			_, err := tx.GetChannelBySCID(
				ctx, sqlc.GetChannelBySCIDParams{
					Scid:    chanIDB,
					Version: int16(lnwire.GossipVersion1),
				},
			)
			if err == nil {
				alreadyExists = true
				return nil
			} else if !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch channel: %w",
					err)
			}

			return insertChannel(ctx, tx, edge)
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case alreadyExists:
				return ErrEdgeAlreadyExist
			default:
				s.rejectCache.remove(edge.ChannelID)
				s.chanCache.remove(edge.ChannelID)
				return nil
			}
		},
	}

	return s.chanScheduler.Execute(ctx, r)
}

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
	var highestChanID uint64
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch highest chan ID: %w",
				err)
		}

		highestChanID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
	}

	return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
	edge *models.ChannelEdgePolicy,
	opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

	var (
		isUpdate1    bool
		edgeNotFound bool
		from, to     route.Vertex
	)

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			isUpdate1 = false
			edgeNotFound = false
		},
		Do: func(tx SQLQueries) error {
			var err error
			from, to, isUpdate1, err = updateChanEdgePolicy(
				ctx, tx, edge,
			)
			// It is possible that two of the same policy
			// announcements are both being processed in the same
			// batch. This may cause the UpsertEdgePolicy conflict to
			// be hit since we require at the db layer that the
			// new last_update is greater than the existing
			// last_update. We need to gracefully handle this here.
			if errors.Is(err, sql.ErrNoRows) {
				return nil
			} else if err != nil {
				log.Errorf("UpdateEdgePolicy failed: %v", err)
			}

			// Silence ErrEdgeNotFound so that the batch can
			// succeed, but propagate the error via local state.
			if errors.Is(err, ErrEdgeNotFound) {
				edgeNotFound = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case edgeNotFound:
				return ErrEdgeNotFound
			default:
				s.updateEdgeCache(edge, isUpdate1)
				return nil
			}
		},
	}

	err := s.chanScheduler.Execute(ctx, r)

	return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
	isUpdate1 bool) {

	// If an entry for this channel is found in reject cache, we'll modify
	// the entry with the updated timestamp for the direction that was just
	// written. If the edge doesn't exist, we'll load the cache entry lazily
	// during the next query for this edge.
	if entry, ok := s.rejectCache.get(e.ChannelID); ok {
		if isUpdate1 {
			entry.upd1Time = e.LastUpdate.Unix()
		} else {
			entry.upd2Time = e.LastUpdate.Unix()
		}
		s.rejectCache.insert(e.ChannelID, entry)
	}

	// If an entry for this channel is found in channel cache, we'll modify
	// the entry with the updated policy for the direction that was just
	// written. If the edge doesn't exist, we'll defer loading the info and
	// policies and lazily read from disk during the next query.
	if channel, ok := s.chanCache.get(e.ChannelID); ok {
		if isUpdate1 {
			channel.Policy1 = e
		} else {
			channel.Policy2 = e
		}
		s.chanCache.insert(e.ChannelID, channel)
	}
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
	cb func(chanPoint wire.OutPoint, havePolicy bool,
		otherNode *models.Node) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		nodeID, nodePub, err := s.getSourceNode(
			ctx, db, lnwire.GossipVersion1,
		)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		return forEachNodeChannel(
			ctx, db, s.cfg, nodeID,
			func(info *models.ChannelEdgeInfo,
				outPolicy *models.ChannelEdgePolicy,
				_ *models.ChannelEdgePolicy) error {

				// Fetch the other node.
				var (
					otherNodePub [33]byte
					node1        = info.NodeKey1Bytes
					node2        = info.NodeKey2Bytes
				)
				switch {
				case bytes.Equal(node1[:], nodePub[:]):
					otherNodePub = node2
				case bytes.Equal(node2[:], nodePub[:]):
					otherNodePub = node1
				default:
					return fmt.Errorf("node not " +
						"participating in this channel")
				}

				_, otherNode, err := getNodeByPubKey(
					ctx, s.cfg.QueryCfg, db, otherNodePub,
				)
				if err != nil {
					return fmt.Errorf("unable to fetch "+
						"other node(%x): %w",
						otherNodePub, err)
				}

				return cb(
					info.ChannelPoint, outPolicy != nil,
					otherNode,
				)
			},
		)
	}, reset)
}
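
// Editor's note: an illustrative sketch, not part of the original file. The
// reset closure should clear any state accumulated by the callback, since the
// underlying read transaction may be retried. The "store" and "ctx" values
// are placeholders.
//
//	var numChans int
//	err := store.ForEachSourceNodeChannel(ctx,
//		func(op wire.OutPoint, havePolicy bool,
//			other *models.Node) error {
//
//			numChans++
//			return nil
//		},
//		func() { numChans = 0 },
//	)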

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
	cb func(node *models.Node) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodePaginated(
			ctx, s.cfg.QueryCfg, db,
			lnwire.GossipVersion1, func(_ context.Context, _ int64,
				node *models.Node) error {

				return cb(node)
			},
		)
	}, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
	cb func(channel *DirectedChannel) error, reset func()) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
	}, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
	cb func(route.Vertex, *lnwire.FeatureVector) error,
	reset func()) error {

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeCacheable(
			ctx, s.cfg.QueryCfg, db,
			func(_ int64, nodePub route.Vertex,
				features *lnwire.FeatureVector) error {

				return cb(nodePub, features)
			},
		)
	}, reset)
	if err != nil {
		return fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(lnwire.GossipVersion1),
				PubKey:  nodePub[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
	}, reset)
}

// extractMaxUpdateTime returns the maximum of the two policy update times.
// This is used for pagination cursor tracking.
func extractMaxUpdateTime(
	row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {

	switch {
	case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
		return max(row.Policy1LastUpdate.Int64,
			row.Policy2LastUpdate.Int64)
	case row.Policy1LastUpdate.Valid:
		return row.Policy1LastUpdate.Int64
	case row.Policy2LastUpdate.Valid:
		return row.Policy2LastUpdate.Int64
	default:
		return 0
	}
}
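
// Editor's note: an illustrative example with made-up values. If
// Policy1LastUpdate = 1700000000 and Policy2LastUpdate = 1700000100 are both
// valid, extractMaxUpdateTime returns 1700000100; if only Policy1LastUpdate
// is valid it returns 1700000000; if neither is valid it returns 0.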

// buildChannelFromRow constructs a ChannelEdge from a database row.
// This includes building the nodes, channel info, and policies.
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
	row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {

	node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
	if err != nil {
		return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
			err)
	}

	node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
	if err != nil {
		return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
			err)
	}

	channel, err := getAndBuildEdgeInfo(
		ctx, s.cfg, db,
		row.GraphChannel, node1.PubKeyBytes,
		node2.PubKeyBytes,
	)
	if err != nil {
		return ChannelEdge{}, fmt.Errorf("unable to build "+
			"channel info: %w", err)
	}

	dbPol1, dbPol2, err := extractChannelPolicies(row)
	if err != nil {
		return ChannelEdge{}, fmt.Errorf("unable to extract "+
			"channel policies: %w", err)
	}

	p1, p2, err := getAndBuildChanPolicies(
		ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
		node1.PubKeyBytes, node2.PubKeyBytes,
	)
	if err != nil {
		return ChannelEdge{}, fmt.Errorf("unable to build "+
			"channel policies: %w", err)
	}

	return ChannelEdge{
		Info:    channel,
		Policy1: p1,
		Policy2: p2,
		Node1:   node1,
		Node2:   node2,
	}, nil
}

// updateChanCacheBatch updates the channel cache with multiple edges at once.
// This method acquires the cache lock only once for the entire batch.
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
	if len(edgesToCache) == 0 {
		return
	}

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	for chanID, edge := range edgesToCache {
		s.chanCache.insert(chanID, edge)
	}
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Iterator Lifecycle:
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
// 2. Query batch of channels with policies in time range
// 3. For each channel: check if seen, check cache, or build from DB
// 4. Yield channels to caller
// 5. Update cache after successful batch
// 6. Repeat with updated pagination cursor until no more results
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
	opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {

	// Apply options.
	cfg := defaultIteratorConfig()
	for _, opt := range opts {
		opt(cfg)
	}

	return func(yield func(ChannelEdge, error) bool) {
		var (
			ctx            = context.TODO()
			edgesSeen      = make(map[uint64]struct{})
			edgesToCache   = make(map[uint64]ChannelEdge)
			hits           int
			total          int
			lastUpdateTime sql.NullInt64
			lastID         sql.NullInt64
			hasMore        = true
		)

		// Each iteration, we'll read a batch amount of channel updates
		// (consulting the cache along the way), yield them, then loop
		// back to decide if we have any more updates to read out.
		for hasMore {
			var batch []ChannelEdge

			err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
				func(db SQLQueries) error {
					//nolint:ll
					params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
						Version: int16(lnwire.GossipVersion1),
						StartTime: sqldb.SQLInt64(
							startTime.Unix(),
						),
						EndTime: sqldb.SQLInt64(
							endTime.Unix(),
						),
						LastUpdateTime: lastUpdateTime,
						LastID:         lastID,
						MaxResults: sql.NullInt32{
							Int32: int32(
								cfg.chanUpdateIterBatchSize,
							),
							Valid: true,
						},
					}
					//nolint:ll
					rows, err := db.GetChannelsByPolicyLastUpdateRange(
						ctx, params,
					)
					if err != nil {
						return err
					}

					//nolint:ll
					hasMore = len(rows) == cfg.chanUpdateIterBatchSize

					//nolint:ll
					for _, row := range rows {
						lastUpdateTime = sql.NullInt64{
							Int64: extractMaxUpdateTime(row),
							Valid: true,
						}
						lastID = sql.NullInt64{
							Int64: row.GraphChannel.ID,
							Valid: true,
						}

						// Skip if we've already
						// processed this channel.
						chanIDInt := byteOrder.Uint64(
							row.GraphChannel.Scid,
						)
						_, ok := edgesSeen[chanIDInt]
						if ok {
							continue
						}

						s.cacheMu.RLock()
						channel, ok := s.chanCache.get(
							chanIDInt,
						)
						s.cacheMu.RUnlock()
						if ok {
							hits++
							total++
							edgesSeen[chanIDInt] = struct{}{}
							batch = append(batch, channel)

							continue
						}

						chanEdge, err := s.buildChannelFromRow(
							ctx, db, row,
						)
						if err != nil {
							return err
						}

						edgesSeen[chanIDInt] = struct{}{}
						edgesToCache[chanIDInt] = chanEdge

						batch = append(batch, chanEdge)

						total++
					}

					return nil
				}, func() {
					batch = nil
					edgesSeen = make(map[uint64]struct{})
					edgesToCache = make(
×
1216
                                                map[uint64]ChannelEdge,
×
1217
                                        )
×
1218
                                })
×
1219

1220
                        if err != nil {
×
1221
                                log.Errorf("ChanUpdatesInHorizon "+
×
1222
                                        "batch error: %v", err)
×
1223

×
1224
                                yield(ChannelEdge{}, err)
×
1225

×
1226
                                return
×
1227
                        }
×
1228

1229
                        for _, edge := range batch {
×
1230
                                if !yield(edge, nil) {
×
1231
                                        return
×
1232
                                }
×
1233
                        }
1234

1235
                        // Update cache after successful batch yield, setting
1236
                        // the cache lock only once for the entire batch.
1237
                        s.updateChanCacheBatch(edgesToCache)
×
1238
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1239

×
1240
                        // If the batch didn't yield anything, then we're done.
×
1241
                        if len(batch) == 0 {
×
1242
                                break
×
1243
                        }
1244
                }
1245

1246
                if total > 0 {
×
1247
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1248
                                "%.2f (%d/%d)",
×
1249
                                float64(hits)*100/float64(total), hits, total)
×
1250
                } else {
×
1251
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1252
                                "in horizon (%s, %s)", startTime, endTime)
×
1253
                }
×
1254
        }
1255
}
1256

1257
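// The iterator above is lazily evaluated: nothing is read from the database
// until the caller ranges over it. A minimal, purely illustrative consumer
// might look like the following sketch, where store, start, end and handle
// are assumptions and not part of this file:
//
//	for edge, err := range store.ChanUpdatesInHorizon(start, end) {
//		if err != nil {
//			return err
//		}
//
//		// Breaking out of the range loop simply stops the underlying
//		// pagination; no further batches are fetched.
//		handle(edge)
//	}
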
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
	cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
		chans map[uint64]*DirectedChannel) error, reset func()) error {

	type nodeCachedBatchData struct {
		features      map[int64][]int
		addrs         map[int64][]nodeAddress
		chanBatchData *batchChannelData
		chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// pageQueryFunc is used to query the next page of nodes.
		pageQueryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

			return db.ListNodeIDsAndPubKeys(
				ctx, sqlc.ListNodeIDsAndPubKeysParams{
					Version: int16(lnwire.GossipVersion1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		// batchDataFunc is then used to batch load the data required
		// for each page of nodes.
		batchDataFunc := func(ctx context.Context,
			nodeIDs []int64) (*nodeCachedBatchData, error) {

			// Batch load node features.
			nodeFeatures, err := batchLoadNodeFeaturesHelper(
				ctx, s.cfg.QueryCfg, db, nodeIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch load "+
					"node features: %w", err)
			}

			// Maybe fetch the node's addresses if requested.
			var nodeAddrs map[int64][]nodeAddress
			if withAddrs {
				nodeAddrs, err = batchLoadNodeAddressesHelper(
					ctx, s.cfg.QueryCfg, db, nodeIDs,
				)
				if err != nil {
					return nil, fmt.Errorf("unable to "+
						"batch load node "+
						"addresses: %w", err)
				}
			}

			// Batch load ALL unique channels for ALL nodes in this
			// page.
			allChannels, err := db.ListChannelsForNodeIDs(
				ctx, sqlc.ListChannelsForNodeIDsParams{
					Version:  int16(lnwire.GossipVersion1),
					Node1Ids: nodeIDs,
					Node2Ids: nodeIDs,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"fetch channels for nodes: %w", err)
			}

			// Deduplicate channels and collect IDs.
			var (
				allChannelIDs []int64
				allPolicyIDs  []int64
			)
			uniqueChannels := make(
				map[int64]sqlc.ListChannelsForNodeIDsRow,
			)

			for _, channel := range allChannels {
				channelID := channel.GraphChannel.ID

				// Only process each unique channel once.
				_, exists := uniqueChannels[channelID]
				if exists {
					continue
				}

				uniqueChannels[channelID] = channel
				allChannelIDs = append(allChannelIDs, channelID)

				if channel.Policy1ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy1ID.Int64,
					)
				}
				if channel.Policy2ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy2ID.Int64,
					)
				}
			}

			// Batch load channel data for all unique channels.
			channelBatchData, err := batchLoadChannelData(
				ctx, s.cfg.QueryCfg, db, allChannelIDs,
				allPolicyIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"load channel data: %w", err)
			}

			// Create map of node ID to channels that involve this
			// node.
			nodeIDSet := make(map[int64]bool)
			for _, nodeID := range nodeIDs {
				nodeIDSet[nodeID] = true
			}

			nodeChannelMap := make(
				map[int64][]sqlc.ListChannelsForNodeIDsRow,
			)
			for _, channel := range uniqueChannels {
				// Add channel to both nodes if they're in our
				// current page.
				node1 := channel.GraphChannel.NodeID1
				if nodeIDSet[node1] {
					nodeChannelMap[node1] = append(
						nodeChannelMap[node1], channel,
					)
				}
				node2 := channel.GraphChannel.NodeID2
				if nodeIDSet[node2] {
					nodeChannelMap[node2] = append(
						nodeChannelMap[node2], channel,
					)
				}
			}

			return &nodeCachedBatchData{
				features:      nodeFeatures,
				addrs:         nodeAddrs,
				chanBatchData: channelBatchData,
				chanMap:       nodeChannelMap,
			}, nil
		}

		// processItem is used to process each node in the current page.
		processItem := func(ctx context.Context,
			nodeData sqlc.ListNodeIDsAndPubKeysRow,
			batchData *nodeCachedBatchData) error {

			// Build feature vector for this node.
			fv := lnwire.EmptyFeatureVector()
			features, exists := batchData.features[nodeData.ID]
			if exists {
				for _, bit := range features {
					fv.Set(lnwire.FeatureBit(bit))
				}
			}

			var nodePub route.Vertex
			copy(nodePub[:], nodeData.PubKey)

			nodeChannels := batchData.chanMap[nodeData.ID]

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			// Build cached channels map for this node.
			channels := make(map[uint64]*DirectedChannel)
			for _, channelRow := range nodeChannels {
				directedChan, err := buildDirectedChannel(
					s.cfg.ChainHash, nodeData.ID, nodePub,
					channelRow, batchData.chanBatchData, fv,
					toNodeCallback,
				)
				if err != nil {
					return err
				}

				channels[directedChan.ChannelID] = directedChan
			}

			addrs, err := buildNodeAddresses(
				batchData.addrs[nodeData.ID],
			)
			if err != nil {
				return fmt.Errorf("unable to build node "+
					"addresses: %w", err)
			}

			return cb(ctx, nodePub, addrs, channels)
		}

		return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
			ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
			func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
				return node.ID
			},
			func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
				error) {

				return node.ID, nil
			},
			batchDataFunc, processItem,
		)
	}, reset)
}

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since, as this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
	reset func()) error {

	ctx := context.TODO()

	handleChannel := func(_ context.Context,
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge := buildCacheableChannelInfo(
			row.Scid, row.Capacity.Int64, node1, node2,
		)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		pol1, pol2, err := buildCachedChanPolicies(
			dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return err
		}

		return cb(edge, pol1, pol2)
	}

	extractCursor := func(
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

		return row.ID
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		//nolint:ll
		queryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
			error) {

			return db.ListChannelsWithPoliciesForCachePaginated(
				ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
					Version: int16(lnwire.GossipVersion1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		return sqldb.ExecutePaginatedQuery(
			ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
			extractCursor, handleChannel,
		)
	}, reset)
}

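// A rough sketch of how the cache-oriented iteration above might be consumed
// when warming an in-memory graph cache; graphCache and its AddChannel method
// are illustrative assumptions and not part of this file:
//
//	err := store.ForEachChannelCacheable(
//		func(info *models.CachedEdgeInfo, pol1,
//			pol2 *models.CachedEdgePolicy) error {
//
//			graphCache.AddChannel(info, pol1, pol2)
//
//			return nil
//		}, func() {},
//	)
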
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since, as this is a directed graph, both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
	}, reset)
}

// FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
	withTimestamps bool) ([]BlockChannelRange, error) {

	var (
		ctx       = context.TODO()
		startSCID = &lnwire.ShortChannelID{
			BlockHeight: startHeight,
		}
		endSCID = lnwire.ShortChannelID{
			BlockHeight: endHeight,
			TxIndex:     math.MaxUint32 & 0x00ffffff,
			TxPosition:  math.MaxUint16,
		}
		chanIDStart = channelIDToBytes(startSCID.ToUint64())
		chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
	)

	// 1) get all channels where channelID is between start and end chan ID.
	// 2) skip if not public (ie, no channel_proof)
	// 3) collect that channel.
	// 4) if timestamps are wanted, fetch both policies for node 1 and node2
	//    and add those timestamps to the collected channel.
	channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChans, err := db.GetPublicV1ChannelsBySCID(
			ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
				StartScid: chanIDStart,
				EndScid:   chanIDEnd,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch channel range: %w",
				err)
		}

		for _, dbChan := range dbChans {
			cid := lnwire.NewShortChanIDFromInt(
				byteOrder.Uint64(dbChan.Scid),
			)
			chanInfo := NewChannelUpdateInfo(
				cid, time.Time{}, time.Time{},
			)

			if !withTimestamps {
				channelsPerBlock[cid.BlockHeight] = append(
					channelsPerBlock[cid.BlockHeight],
					chanInfo,
				)

				continue
			}

			//nolint:ll
			node1Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(lnwire.GossipVersion1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID1,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node1 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node1UpdateTimestamp = time.Unix(
					node1Policy.LastUpdate.Int64, 0,
				)
			}

			//nolint:ll
			node2Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(lnwire.GossipVersion1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID2,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node2 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node2UpdateTimestamp = time.Unix(
					node2Policy.LastUpdate.Int64, 0,
				)
			}

			channelsPerBlock[cid.BlockHeight] = append(
				channelsPerBlock[cid.BlockHeight], chanInfo,
			)
		}

		return nil
	}, func() {
		channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channel range: %w", err)
	}

	if len(channelsPerBlock) == 0 {
		return nil, nil
	}

	// Return the channel ranges in ascending block height order.
	blocks := slices.Collect(maps.Keys(channelsPerBlock))
	slices.Sort(blocks)

	return fn.Map(blocks, func(block uint32) BlockChannelRange {
		return BlockChannelRange{
			Height:   block,
			Channels: channelsPerBlock[block],
		}
	}), nil
}

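// As a rough usage sketch, a gossip syncer answering a query_channel_range
// style request could flatten the result above into per-block SCID lists; the
// surrounding variables (store, firstBlock, lastBlock, scids) are
// illustrative assumptions only:
//
//	ranges, err := store.FilterChannelRange(firstBlock, lastBlock, true)
//	if err != nil {
//		return err
//	}
//	for _, blockRange := range ranges {
//		for _, info := range blockRange.Channels {
//			scids = append(scids, info.ShortChannelID)
//		}
//	}
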
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
	pubKey1, pubKey2 [33]byte) error {

	ctx := context.TODO()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	chanIDB := channelIDToBytes(chanID)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		return db.UpsertZombieChannel(
			ctx, sqlc.UpsertZombieChannelParams{
				Version:  int16(lnwire.GossipVersion1),
				Scid:     chanIDB,
				NodeKey1: pubKey1[:],
				NodeKey2: pubKey2[:],
			},
		)
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to upsert zombie channel "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return nil
}

// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx     = context.TODO()
		chanIDB = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteZombieChannel(
			ctx, sqlc.DeleteZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(lnwire.GossipVersion1),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete zombie channel: %w",
				err)
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrZombieEdgeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d zombie rows, "+
				"expected 1", rows)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to mark edge live "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return err
}

// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
	error) {

	var (
		ctx              = context.TODO()
		isZombie         bool
		pubKey1, pubKey2 route.Vertex
		chanIDB          = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		zombie, err := db.GetZombieChannel(
			ctx, sqlc.GetZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(lnwire.GossipVersion1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("unable to fetch zombie channel: %w",
				err)
		}

		copy(pubKey1[:], zombie.NodeKey1)
		copy(pubKey2[:], zombie.NodeKey2)
		isZombie = true

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, route.Vertex{}, route.Vertex{},
			fmt.Errorf("%w: %w (chanID=%d)",
				ErrCantCheckIfZombieEdgeStr, err, chanID)
	}

	return isZombie, pubKey1, pubKey2, nil
}

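// Taken together, MarkEdgeZombie, IsZombieEdge and MarkEdgeLive implement a
// simple lifecycle around the zombie index. A hedged sketch of that flow,
// with chanID and the two pubkeys assumed to come from the caller:
//
//	if err := store.MarkEdgeZombie(chanID, pub1, pub2); err != nil {
//		return err
//	}
//	isZombie, _, _, err := store.IsZombieEdge(chanID) // true
//	if err != nil {
//		return err
//	}
//	if isZombie {
//		// A fresh, valid channel update later resurrects the edge.
//		err = store.MarkEdgeLive(chanID)
//	}
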
// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
	var (
		ctx        = context.TODO()
		numZombies uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		count, err := db.CountZombieChannels(
			ctx, int16(lnwire.GossipVersion1),
		)
		if err != nil {
			return fmt.Errorf("unable to count zombie channels: %w",
				err)
		}

		numZombies = uint64(count)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to count zombies: %w", err)
	}

	return numZombies, nil
}

// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
	chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// Keep track of which channels we end up finding so that we can
	// correctly return ErrEdgeNotFound if we do not find a channel.
	chanLookup := make(map[uint64]struct{}, len(chanIDs))
	for _, chanID := range chanIDs {
		chanLookup[chanID] = struct{}{}
	}

	var (
		ctx   = context.TODO()
		edges []*models.ChannelEdgeInfo
	)
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// First, collect all channel rows.
		var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			// Deleting the entry from the map indicates that we
			// have found the channel.
			scid := byteOrder.Uint64(row.GraphChannel.Scid)
			delete(chanLookup, scid)

			channelRows = append(channelRows, row)

			return nil
		}

		err := s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
		if err != nil {
			return err
		}

		if len(chanLookup) > 0 {
			return ErrEdgeNotFound
		}

		if len(channelRows) == 0 {
			return nil
		}

		// Batch build all channel edges.
		var chanIDsToDelete []int64
		edges, chanIDsToDelete, err = batchBuildChannelInfo(
			ctx, s.cfg, db, channelRows,
		)
		if err != nil {
			return err
		}

		if markZombie {
			for i, row := range channelRows {
				scid := byteOrder.Uint64(row.GraphChannel.Scid)

				err := handleZombieMarking(
					ctx, db, row, edges[i],
					strictZombiePruning, scid,
				)
				if err != nil {
					return fmt.Errorf("unable to mark "+
						"channel as zombie: %w", err)
				}
			}
		}

		return s.deleteChannels(ctx, db, chanIDsToDelete)
	}, func() {
		edges = nil

		// Re-fill the lookup map.
		for _, chanID := range chanIDs {
			chanLookup[chanID] = struct{}{}
		}
	})
	if err != nil {
		return nil, fmt.Errorf("unable to delete channel edges: %w",
			err)
	}

	for _, chanID := range chanIDs {
		s.rejectCache.remove(chanID)
		s.chanCache.remove(chanID)
	}

	return edges, nil
}

// FetchChannelEdgesByID attempts to lookup the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
		chanIDB          = channelIDToBytes(chanID)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelBySCIDWithPolicies(
			ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
				Scid:    chanIDB,
				Version: int16(lnwire.GossipVersion1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// First check if this edge is perhaps in the zombie
			// index.
			zombie, err := db.GetZombieChannel(
				ctx, sqlc.GetZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(lnwire.GossipVersion1),
				},
			)
			if errors.Is(err, sql.ErrNoRows) {
				return ErrEdgeNotFound
			} else if err != nil {
				return fmt.Errorf("unable to check if "+
					"channel is zombie: %w", err)
			}

			// At this point, we know the channel is a zombie, so
			// we'll return an error indicating this, and we will
			// populate the edge info with the public keys of each
			// party as this is the only information we have about
			// it.
			edge = &models.ChannelEdgeInfo{}
			copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
			copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)

			return ErrZombieEdge
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.GraphNode.PubKey, row.GraphNode_2.PubKey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		// If we are returning the ErrZombieEdge, then we also need to
		// return the edge info as the method comment indicates that
		// this will be populated when the edge is a zombie.
		return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

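// Callers of FetchChannelEdgesByID typically need to treat ErrZombieEdge as a
// semi-successful result, since the returned edge info still carries the node
// public keys. An illustrative sketch, with handleZombie and handleLive being
// assumed caller-side helpers:
//
//	info, pol1, pol2, err := store.FetchChannelEdgesByID(chanID)
//	switch {
//	case errors.Is(err, ErrZombieEdge):
//		// Only info.NodeKey1Bytes/NodeKey2Bytes are populated here.
//		handleZombie(info)
//	case err != nil:
//		return err
//	default:
//		handleLive(info, pol1, pol2)
//	}
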
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelByOutpointWithPolicies(
			ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
				Outpoint: op.String(),
				Version:  int16(lnwire.GossipVersion1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
	bool, error) {

	ctx := context.TODO()

	var (
		exists          bool
		isZombie        bool
		node1LastUpdate time.Time
		node2LastUpdate time.Time
	)

	// We'll query the cache with the shared lock held to allow multiple
	// readers to access values in the cache concurrently if they exist.
	s.cacheMu.RLock()
	if entry, ok := s.rejectCache.get(chanID); ok {
		s.cacheMu.RUnlock()
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}
	s.cacheMu.RUnlock()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// The item was not found with the shared lock, so we'll acquire the
	// exclusive lock and check the cache again in case another method added
	// the entry to the cache while no lock was held.
	if entry, ok := s.rejectCache.get(chanID); ok {
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}

	chanIDB := channelIDToBytes(chanID)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		channel, err := db.GetChannelBySCID(
			ctx, sqlc.GetChannelBySCIDParams{
				Scid:    chanIDB,
				Version: int16(lnwire.GossipVersion1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// Check if it is a zombie channel.
			isZombie, err = db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(lnwire.GossipVersion1),
				},
			)
			if err != nil {
				return fmt.Errorf("could not check if channel "+
					"is zombie: %w", err)
			}

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		exists = true

		policy1, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(lnwire.GossipVersion1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID1,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
		}

		policy2, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(lnwire.GossipVersion1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID2,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, time.Time{}, false, false,
			fmt.Errorf("unable to fetch channel: %w", err)
	}

	s.rejectCache.insert(chanID, rejectCacheEntry{
		upd1Time: node1LastUpdate.Unix(),
		upd2Time: node2LastUpdate.Unix(),
		flags:    packRejectFlags(exists, isZombie),
	})

	return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}

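// A typical (hypothetical) gossip-side use of HasChannelEdge is to decide
// whether an incoming channel_update needs any further processing at all:
//
//	upd1, upd2, exists, isZombie, err := store.HasChannelEdge(scid)
//	if err != nil {
//		return err
//	}
//	if isZombie || !exists {
//		return nil // Nothing to refresh for this SCID.
//	}
//	// upd1/upd2 can then be compared, per direction, against the
//	// message's timestamp to detect stale gossip.
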
// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
	var (
		ctx       = context.TODO()
		channelID uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.GetSCIDByOutpoint(
			ctx, sqlc.GetSCIDByOutpointParams{
				Outpoint: chanPoint.String(),
				Version:  int16(lnwire.GossipVersion1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel ID: %w",
				err)
		}

		channelID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
	}

	return channelID, nil
}

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
	ctx := context.TODO()

	var isPublic bool
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to check if node is "+
			"public: %w", err)
	}

	return isPublic, nil
}

// FetchChanInfos returns the set of channel edges that correspond to the passed
2303
// channel ID's. If an edge is the query is unknown to the database, it will
2304
// skipped and the result will contain only those edges that exist at the time
2305
// of the query. This can be used to respond to peer queries that are seeking to
2306
// fill in gaps in their view of the channel graph.
2307
//
2308
// NOTE: part of the V1Store interface.
2309
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2310
        var (
×
2311
                ctx   = context.TODO()
×
2312
                edges = make(map[uint64]ChannelEdge)
×
2313
        )
×
2314
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2315
                // First, collect all channel rows.
×
2316
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2317
                chanCallBack := func(ctx context.Context,
×
2318
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2319

×
2320
                        channelRows = append(channelRows, row)
×
2321
                        return nil
×
2322
                }
×
2323

2324
                err := s.forEachChanWithPoliciesInSCIDList(
×
2325
                        ctx, db, chanCallBack, chanIDs,
×
2326
                )
×
2327
                if err != nil {
×
2328
                        return err
×
2329
                }
×
2330

2331
                if len(channelRows) == 0 {
×
2332
                        return nil
×
2333
                }
×
2334

2335
                // Batch build all channel edges.
2336
                chans, err := batchBuildChannelEdges(
×
2337
                        ctx, s.cfg, db, channelRows,
×
2338
                )
×
2339
                if err != nil {
×
2340
                        return fmt.Errorf("unable to build channel edges: %w",
×
2341
                                err)
×
2342
                }
×
2343

2344
                for _, c := range chans {
×
2345
                        edges[c.Info.ChannelID] = c
×
2346
                }
×
2347

2348
                return err
×
2349
        }, func() {
×
2350
                clear(edges)
×
2351
        })
×
2352
        if err != nil {
×
2353
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2354
        }
×
2355

2356
        res := make([]ChannelEdge, 0, len(edges))
×
2357
        for _, chanID := range chanIDs {
×
2358
                edge, ok := edges[chanID]
×
2359
                if !ok {
×
2360
                        continue
×
2361
                }
2362

2363
                res = append(res, edge)
×
2364
        }
2365

2366
        return res, nil
×
2367
}
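
Illustrative usage (hypothetical helper, not part of the original file): a gossip syncer answering a peer's SCID query can pass the requested IDs straight through, since unknown channels are silently skipped and the result preserves the request order.

func knownChanIDs(store *SQLStore, requested []uint64) ([]uint64, error) {
        edges, err := store.FetchChanInfos(requested)
        if err != nil {
                return nil, err
        }

        // Only the channels the store actually knows about are returned.
        known := make([]uint64, 0, len(edges))
        for _, edge := range edges {
                known = append(known, edge.Info.ChannelID)
        }

        return known, nil
}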
2368

2369
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2370
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2371
// channels in a paginated manner.
2372
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2373
        db SQLQueries, cb func(ctx context.Context,
2374
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2375
        chanIDs []uint64) error {
×
2376

×
2377
        queryWrapper := func(ctx context.Context,
×
2378
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2379
                error) {
×
2380

×
2381
                return db.GetChannelsBySCIDWithPolicies(
×
2382
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2383
                                Version: int16(lnwire.GossipVersion1),
×
2384
                                Scids:   scids,
×
2385
                        },
×
2386
                )
×
2387
        }
×
2388

2389
        return sqldb.ExecuteBatchQuery(
×
2390
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2391
                cb,
×
2392
        )
×
2393
}
2394

2395
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
2396
// IDs that we don't know and are not known zombies of the passed set. In other
2397
// words, we perform a set difference of our set of chan IDs and the ones
2398
// passed in. This method can be used by callers to determine the set of
2399
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2400
// known zombies are also returned.
2401
//
2402
// NOTE: part of the V1Store interface.
2403
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2404
        []ChannelUpdateInfo, error) {
×
2405

×
2406
        var (
×
2407
                ctx          = context.TODO()
×
2408
                newChanIDs   []uint64
×
2409
                knownZombies []ChannelUpdateInfo
×
2410
                infoLookup   = make(
×
2411
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2412
                )
×
2413
        )
×
2414

×
2415
        // We first build a lookup map of the channel ID's to the
×
2416
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2417
        // already know about.
×
2418
        for _, chanInfo := range chansInfo {
×
2419
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2420
        }
×
2421

2422
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2423
                // The call-back function deletes known channels from
×
2424
                // infoLookup, so that we can later check which channels are
×
2425
                // zombies by only looking at the remaining channels in the set.
×
2426
                cb := func(ctx context.Context,
×
2427
                        channel sqlc.GraphChannel) error {
×
2428

×
2429
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2430

×
2431
                        return nil
×
2432
                }
×
2433

2434
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2435
                if err != nil {
×
2436
                        return fmt.Errorf("unable to iterate through "+
×
2437
                                "channels: %w", err)
×
2438
                }
×
2439

2440
                // We want to ensure that we deal with the channels in the
2441
                // same order that they were passed in, so we iterate over the
2442
                // original chansInfo slice and then check if that channel is
2443
                // still in the infoLookup map.
2444
                for _, chanInfo := range chansInfo {
×
2445
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2446
                        if _, ok := infoLookup[channelID]; !ok {
×
2447
                                continue
×
2448
                        }
2449

2450
                        isZombie, err := db.IsZombieChannel(
×
2451
                                ctx, sqlc.IsZombieChannelParams{
×
2452
                                        Scid:    channelIDToBytes(channelID),
×
2453
                                        Version: int16(lnwire.GossipVersion1),
×
2454
                                },
×
2455
                        )
×
2456
                        if err != nil {
×
2457
                                return fmt.Errorf("unable to fetch zombie "+
×
2458
                                        "channel: %w", err)
×
2459
                        }
×
2460

2461
                        if isZombie {
×
2462
                                knownZombies = append(knownZombies, chanInfo)
×
2463

×
2464
                                continue
×
2465
                        }
2466

2467
                        newChanIDs = append(newChanIDs, channelID)
×
2468
                }
2469

2470
                return nil
×
2471
        }, func() {
×
2472
                newChanIDs = nil
×
2473
                knownZombies = nil
×
2474
                // Rebuild the infoLookup map in case of a rollback.
×
2475
                for _, chanInfo := range chansInfo {
×
2476
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2477
                        infoLookup[scid] = chanInfo
×
2478
                }
×
2479
        })
2480
        if err != nil {
×
2481
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2482
        }
×
2483

2484
        return newChanIDs, knownZombies, nil
×
2485
}
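
Illustrative usage (hypothetical helper, not part of the original file): given the channel list advertised by a peer, query only the SCIDs we have never seen, and leave any known zombies to a separate resurrection decision.

func newChannelsToQuery(store *SQLStore,
        advertised []ChannelUpdateInfo) ([]uint64, error) {

        unknown, zombies, err := store.FilterKnownChanIDs(advertised)
        if err != nil {
                return nil, err
        }

        // A real caller could inspect the zombies' update timestamps and
        // decide to re-query some of them; here they are simply ignored.
        _ = zombies

        return unknown, nil
}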
2486

2487
// forEachChanInSCIDList is a helper method that executes a paged query
2488
// against the database to fetch all channels that match the passed
2489
// ChannelUpdateInfo slice. The callback function is called for each channel
2490
// that is found.
2491
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2492
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2493
        chansInfo []ChannelUpdateInfo) error {
×
2494

×
2495
        queryWrapper := func(ctx context.Context,
×
2496
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2497

×
2498
                return db.GetChannelsBySCIDs(
×
2499
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2500
                                Version: int16(lnwire.GossipVersion1),
×
2501
                                Scids:   scids,
×
2502
                        },
×
2503
                )
×
2504
        }
×
2505

2506
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2507
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2508

×
2509
                return channelIDToBytes(channelID)
×
2510
        }
×
2511

2512
        return sqldb.ExecuteBatchQuery(
×
2513
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2514
                cb,
×
2515
        )
×
2516
}
2517

2518
// PruneGraphNodes is a garbage collection method which attempts to prune out
2519
// any nodes from the channel graph that are currently unconnected. This ensures
2520
// that we only maintain a graph of reachable nodes. In the event that a pruned
2521
// node gains more channels, it will be re-added to the graph.
2522
//
2523
// NOTE: this prunes nodes across protocol versions. It will never prune the
2524
// source nodes.
2525
//
2526
// NOTE: part of the V1Store interface.
2527
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2528
        var ctx = context.TODO()
×
2529

×
2530
        var prunedNodes []route.Vertex
×
2531
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2532
                var err error
×
2533
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2534

×
2535
                return err
×
2536
        }, func() {
×
2537
                prunedNodes = nil
×
2538
        })
×
2539
        if err != nil {
×
2540
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2541
        }
×
2542

2543
        return prunedNodes, nil
×
2544
}
2545

2546
// PruneGraph prunes newly closed channels from the channel graph in response
2547
// to a new block being solved on the network. Any transactions which spend the
2548
// funding output of any known channels within the graph will be deleted.
2549
// Additionally, the "prune tip", or the last block which has been used to
2550
// prune the graph is stored so callers can ensure the graph is fully in sync
2551
// with the current UTXO state. A slice of channels that have been closed by
2552
// the target block along with any pruned nodes are returned if the function
2553
// succeeds without error.
2554
//
2555
// NOTE: part of the V1Store interface.
2556
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2557
        blockHash *chainhash.Hash, blockHeight uint32) (
2558
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2559

×
2560
        ctx := context.TODO()
×
2561

×
2562
        s.cacheMu.Lock()
×
2563
        defer s.cacheMu.Unlock()
×
2564

×
2565
        var (
×
2566
                closedChans []*models.ChannelEdgeInfo
×
2567
                prunedNodes []route.Vertex
×
2568
        )
×
2569
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2570
                // First, collect all channel rows that need to be pruned.
×
2571
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2572
                channelCallback := func(ctx context.Context,
×
2573
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2574

×
2575
                        channelRows = append(channelRows, row)
×
2576

×
2577
                        return nil
×
2578
                }
×
2579

2580
                err := s.forEachChanInOutpoints(
×
2581
                        ctx, db, spentOutputs, channelCallback,
×
2582
                )
×
2583
                if err != nil {
×
2584
                        return fmt.Errorf("unable to fetch channels by "+
×
2585
                                "outpoints: %w", err)
×
2586
                }
×
2587

2588
                if len(channelRows) == 0 {
×
2589
                        // There are no channels to prune. So we can exit early
×
2590
                        // after updating the prune log.
×
2591
                        err = db.UpsertPruneLogEntry(
×
2592
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2593
                                        BlockHash:   blockHash[:],
×
2594
                                        BlockHeight: int64(blockHeight),
×
2595
                                },
×
2596
                        )
×
2597
                        if err != nil {
×
2598
                                return fmt.Errorf("unable to insert prune log "+
×
2599
                                        "entry: %w", err)
×
2600
                        }
×
2601

2602
                        return nil
×
2603
                }
2604

2605
                // Batch build all channel edges for pruning.
2606
                var chansToDelete []int64
×
2607
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2608
                        ctx, s.cfg, db, channelRows,
×
2609
                )
×
2610
                if err != nil {
×
2611
                        return err
×
2612
                }
×
2613

2614
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2615
                if err != nil {
×
2616
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2617
                }
×
2618

2619
                err = db.UpsertPruneLogEntry(
×
2620
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2621
                                BlockHash:   blockHash[:],
×
2622
                                BlockHeight: int64(blockHeight),
×
2623
                        },
×
2624
                )
×
2625
                if err != nil {
×
2626
                        return fmt.Errorf("unable to insert prune log "+
×
2627
                                "entry: %w", err)
×
2628
                }
×
2629

2630
                // Now that we've pruned some channels, we'll also prune any
2631
                // nodes that no longer have any channels.
2632
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2633
                if err != nil {
×
2634
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2635
                                err)
×
2636
                }
×
2637

2638
                return nil
×
2639
        }, func() {
×
2640
                prunedNodes = nil
×
2641
                closedChans = nil
×
2642
        })
×
2643
        if err != nil {
×
2644
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2645
        }
×
2646

2647
        for _, channel := range closedChans {
×
2648
                s.rejectCache.remove(channel.ChannelID)
×
2649
                s.chanCache.remove(channel.ChannelID)
×
2650
        }
×
2651

2652
        return closedChans, prunedNodes, nil
×
2653
}
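
Illustrative wiring (hypothetical, not part of the original file): for every connected block, hand the spent funding outpoints to PruneGraph together with the block hash and height, then act on the returned slices as needed.

func pruneForBlock(store *SQLStore, blockHash *chainhash.Hash,
        blockHeight uint32, spentOutpoints []*wire.OutPoint) error {

        closedChans, prunedNodes, err := store.PruneGraph(
                spentOutpoints, blockHash, blockHeight,
        )
        if err != nil {
                return err
        }

        // The caller can use the returned slices for logging or to notify
        // other subsystems about channels that are now closed.
        fmt.Printf("block %d: pruned %d channels, %d nodes\n",
                blockHeight, len(closedChans), len(prunedNodes))

        return nil
}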
2654

2655
// forEachChanInOutpoints is a helper function that executes a paginated
2656
// query to fetch channels by their outpoints and applies the given call-back
2657
// to each.
2658
//
2659
// NOTE: this fetches channels for all protocol versions.
2660
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2661
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2662
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2663

×
2664
        // Create a wrapper that uses the transaction's db instance to execute
×
2665
        // the query.
×
2666
        queryWrapper := func(ctx context.Context,
×
2667
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2668
                error) {
×
2669

×
2670
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2671
        }
×
2672

2673
        // Define the conversion function from Outpoint to string.
2674
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2675
                return outpoint.String()
×
2676
        }
×
2677

2678
        return sqldb.ExecuteBatchQuery(
×
2679
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2680
                queryWrapper, cb,
×
2681
        )
×
2682
}
2683

2684
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2685
        dbIDs []int64) error {
×
2686

×
2687
        // Create a wrapper that uses the transaction's db instance to execute
×
2688
        // the query.
×
2689
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2690
                return nil, db.DeleteChannels(ctx, ids)
×
2691
        }
×
2692

2693
        idConverter := func(id int64) int64 {
×
2694
                return id
×
2695
        }
×
2696

2697
        return sqldb.ExecuteBatchQuery(
×
2698
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2699
                queryWrapper, func(ctx context.Context, _ any) error {
×
2700
                        return nil
×
2701
                },
×
2702
        )
2703
}
2704

2705
// ChannelView returns the verifiable edge information for each active channel
2706
// within the known channel graph. The set of UTXOs (along with their scripts)
2707
// returned are the ones that need to be watched on chain to detect channel
2708
// closes on the resident blockchain.
2709
//
2710
// NOTE: part of the V1Store interface.
2711
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2712
        var (
×
2713
                ctx        = context.TODO()
×
2714
                edgePoints []EdgePoint
×
2715
        )
×
2716

×
2717
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2718
                handleChannel := func(_ context.Context,
×
2719
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2720

×
2721
                        pkScript, err := genMultiSigP2WSH(
×
2722
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2723
                        )
×
2724
                        if err != nil {
×
2725
                                return err
×
2726
                        }
×
2727

2728
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2729
                        if err != nil {
×
2730
                                return err
×
2731
                        }
×
2732

2733
                        edgePoints = append(edgePoints, EdgePoint{
×
2734
                                FundingPkScript: pkScript,
×
2735
                                OutPoint:        *op,
×
2736
                        })
×
2737

×
2738
                        return nil
×
2739
                }
2740

2741
                queryFunc := func(ctx context.Context, lastID int64,
×
2742
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2743

×
2744
                        return db.ListChannelsPaginated(
×
2745
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2746
                                        Version: int16(lnwire.GossipVersion1),
×
2747
                                        ID:      lastID,
×
2748
                                        Limit:   limit,
×
2749
                                },
×
2750
                        )
×
2751
                }
×
2752

2753
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2754
                        return row.ID
×
2755
                }
×
2756

2757
                return sqldb.ExecutePaginatedQuery(
×
2758
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2759
                        extractCursor, handleChannel,
×
2760
                )
×
2761
        }, func() {
×
2762
                edgePoints = nil
×
2763
        })
×
2764
        if err != nil {
×
2765
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2766
        }
×
2767

2768
        return edgePoints, nil
×
2769
}
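
Illustrative consumer (hypothetical, not part of the original file): register every funding outpoint and its pkScript with a chain watcher. The registerSpend callback stands in for whatever spend-notification API the caller uses and is not part of this package.

func watchChannelPoints(store *SQLStore,
        registerSpend func(op wire.OutPoint, pkScript []byte) error) error {

        edgePoints, err := store.ChannelView()
        if err != nil {
                return err
        }

        for _, ep := range edgePoints {
                err := registerSpend(ep.OutPoint, ep.FundingPkScript)
                if err != nil {
                        return err
                }
        }

        return nil
}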
2770

2771
// PruneTip returns the block height and hash of the latest block that has been
2772
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2773
// to tell if the graph is currently in sync with the current best known UTXO
2774
// state.
2775
//
2776
// NOTE: part of the V1Store interface.
2777
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2778
        var (
×
2779
                ctx       = context.TODO()
×
2780
                tipHash   chainhash.Hash
×
2781
                tipHeight uint32
×
2782
        )
×
2783
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2784
                pruneTip, err := db.GetPruneTip(ctx)
×
2785
                if errors.Is(err, sql.ErrNoRows) {
×
2786
                        return ErrGraphNeverPruned
×
2787
                } else if err != nil {
×
2788
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2789
                }
×
2790

2791
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2792
                tipHeight = uint32(pruneTip.BlockHeight)
×
2793

×
2794
                return nil
×
2795
        }, sqldb.NoOpReset)
2796
        if err != nil {
×
2797
                return nil, 0, err
×
2798
        }
×
2799

2800
        return &tipHash, tipHeight, nil
×
2801
}
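
Illustrative check (hypothetical, not part of the original file): compare the prune tip against the current best block, treating ErrGraphNeverPruned as "not synced yet".

func isSyncedToChain(store *SQLStore, bestHash *chainhash.Hash,
        bestHeight uint32) (bool, error) {

        tipHash, tipHeight, err := store.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                // The graph has never been pruned, so a full initial prune
                // is still required.
                return false, nil

        case err != nil:
                return false, err
        }

        return tipHeight == bestHeight && tipHash.IsEqual(bestHash), nil
}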
2802

2803
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2804
//
2805
// NOTE: this prunes nodes across protocol versions. It will never prune the
2806
// source nodes.
2807
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2808
        db SQLQueries) ([]route.Vertex, error) {
×
2809

×
2810
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2811
        if err != nil {
×
2812
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2813
                        "nodes: %w", err)
×
2814
        }
×
2815

2816
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2817
        for i, nodeKey := range nodeKeys {
×
2818
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2819
                if err != nil {
×
2820
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2821
                                "from bytes: %w", err)
×
2822
                }
×
2823

2824
                prunedNodes[i] = pub
×
2825
        }
2826

2827
        return prunedNodes, nil
×
2828
}
2829

2830
// DisconnectBlockAtHeight is used to indicate that the block specified
2831
// by the passed height has been disconnected from the main chain. This
2832
// will "rewind" the graph back to the height below, deleting channels
2833
// that are no longer confirmed from the graph. The prune log will be
2834
// set to the last prune height valid for the remaining chain.
2835
// Channels that were removed from the graph as a result of the
2836
// disconnected block are returned.
2837
//
2838
// NOTE: part of the V1Store interface.
2839
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2840
        []*models.ChannelEdgeInfo, error) {
×
2841

×
2842
        ctx := context.TODO()
×
2843

×
2844
        var (
×
2845
                // Every channel having a ShortChannelID starting at 'height'
×
2846
                // will no longer be confirmed.
×
2847
                startShortChanID = lnwire.ShortChannelID{
×
2848
                        BlockHeight: height,
×
2849
                }
×
2850

×
2851
                // Delete everything after this height from the db up until the
×
2852
                // SCID alias range.
×
2853
                endShortChanID = aliasmgr.StartingAlias
×
2854

×
2855
                removedChans []*models.ChannelEdgeInfo
×
2856

×
2857
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2858
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2859
        )
×
2860

×
2861
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2862
                rows, err := db.GetChannelsBySCIDRange(
×
2863
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2864
                                StartScid: chanIDStart,
×
2865
                                EndScid:   chanIDEnd,
×
2866
                        },
×
2867
                )
×
2868
                if err != nil {
×
2869
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2870
                }
×
2871

2872
                if len(rows) == 0 {
×
2873
                        // No channels to disconnect, but still clean up prune
×
2874
                        // log.
×
2875
                        return db.DeletePruneLogEntriesInRange(
×
2876
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2877
                                        StartHeight: int64(height),
×
2878
                                        EndHeight: int64(
×
2879
                                                endShortChanID.BlockHeight,
×
2880
                                        ),
×
2881
                                },
×
2882
                        )
×
2883
                }
×
2884

2885
                // Batch build all channel edges for disconnection.
2886
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2887
                        ctx, s.cfg, db, rows,
×
2888
                )
×
2889
                if err != nil {
×
2890
                        return err
×
2891
                }
×
2892

2893
                removedChans = channelEdges
×
2894

×
2895
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2896
                if err != nil {
×
2897
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2898
                }
×
2899

2900
                return db.DeletePruneLogEntriesInRange(
×
2901
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2902
                                StartHeight: int64(height),
×
2903
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2904
                        },
×
2905
                )
×
2906
        }, func() {
×
2907
                removedChans = nil
×
2908
        })
×
2909
        if err != nil {
×
2910
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2911
                        "height: %w", err)
×
2912
        }
×
2913

2914
        for _, channel := range removedChans {
×
2915
                s.rejectCache.remove(channel.ChannelID)
×
2916
                s.chanCache.remove(channel.ChannelID)
×
2917
        }
×
2918

2919
        return removedChans, nil
×
2920
}
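
Illustrative reorg handling (hypothetical, not part of the original file): rewind the graph when a block is disconnected and report the channels that were dropped.

func handleBlockDisconnected(store *SQLStore, staleHeight uint32) error {
        removed, err := store.DisconnectBlockAtHeight(staleHeight)
        if err != nil {
                return err
        }

        for _, edge := range removed {
                fmt.Printf("channel %d removed by reorg at height %d\n",
                        edge.ChannelID, staleHeight)
        }

        return nil
}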
2921

2922
// AddEdgeProof sets the proof of an existing edge in the graph database.
2923
//
2924
// NOTE: part of the V1Store interface.
2925
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2926
        proof *models.ChannelAuthProof) error {
×
2927

×
2928
        var (
×
2929
                ctx       = context.TODO()
×
2930
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2931
        )
×
2932

×
2933
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2934
                res, err := db.AddV1ChannelProof(
×
2935
                        ctx, sqlc.AddV1ChannelProofParams{
×
2936
                                Scid:              scidBytes,
×
2937
                                Node1Signature:    proof.NodeSig1Bytes,
×
2938
                                Node2Signature:    proof.NodeSig2Bytes,
×
2939
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2940
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2941
                        },
×
2942
                )
×
2943
                if err != nil {
×
2944
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2945
                }
×
2946

2947
                n, err := res.RowsAffected()
×
2948
                if err != nil {
×
2949
                        return err
×
2950
                }
×
2951

2952
                if n == 0 {
×
2953
                        return fmt.Errorf("no rows affected when adding edge "+
×
2954
                                "proof for SCID %v", scid)
×
2955
                } else if n > 1 {
×
2956
                        return fmt.Errorf("multiple rows affected when adding "+
×
2957
                                "edge proof for SCID %v: %d rows affected",
×
2958
                                scid, n)
×
2959
                }
×
2960

2961
                return nil
×
2962
        }, sqldb.NoOpReset)
2963
        if err != nil {
×
2964
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2965
        }
×
2966

2967
        return nil
×
2968
}
2969

2970
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2971
// that we can ignore channel announcements that we know to be closed without
2972
// having to validate them and fetch a block.
2973
//
2974
// NOTE: part of the V1Store interface.
2975
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2976
        var (
×
2977
                ctx     = context.TODO()
×
2978
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2979
        )
×
2980

×
2981
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2982
                return db.InsertClosedChannel(ctx, chanIDB)
×
2983
        }, sqldb.NoOpReset)
×
2984
}
2985

2986
// IsClosedScid checks whether a channel identified by the passed in scid is
2987
// closed. This helps avoid having to perform expensive validation checks.
2988
//
2989
// NOTE: part of the V1Store interface.
2990
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2991
        var (
×
2992
                ctx      = context.TODO()
×
2993
                isClosed bool
×
2994
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2995
        )
×
2996
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2997
                var err error
×
2998
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2999
                if err != nil {
×
3000
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3001
                                err)
×
3002
                }
×
3003

3004
                return nil
×
3005
        }, sqldb.NoOpReset)
3006
        if err != nil {
×
3007
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3008
                        err)
×
3009
        }
×
3010

3011
        return isClosed, nil
×
3012
}
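
Illustrative gossip-side check (hypothetical, not part of the original file): PutClosedScid and IsClosedScid together act as a negative cache, so an announcement for a known-closed SCID can be rejected without fetching a block.

func shouldValidateAnnouncement(store *SQLStore,
        scid lnwire.ShortChannelID) (bool, error) {

        closed, err := store.IsClosedScid(scid)
        if err != nil {
                return false, err
        }

        // Announcements for channels we already know to be closed can be
        // dropped without any on-chain validation.
        return !closed, nil
}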
3013

3014
// GraphSession will provide the call-back with access to a NodeTraverser
3015
// instance which can be used to perform queries against the channel graph.
3016
//
3017
// NOTE: part of the V1Store interface.
3018
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3019
        reset func()) error {
×
3020

×
3021
        var ctx = context.TODO()
×
3022

×
3023
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3024
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3025
        }, reset)
×
3026
}
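
Illustrative usage (hypothetical, not part of the original file): count a node's channels inside a single session so that every traversal callback sees one consistent snapshot of the graph. The second closure resets any state accumulated by the callback in case the underlying read transaction is retried.

func countChannelsInSession(store *SQLStore, node route.Vertex) (int, error) {
        var count int
        err := store.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        node, func(_ *DirectedChannel) error {
                                count++
                                return nil
                        }, func() {},
                )
        }, func() {
                // Reset the counter if the transaction is retried.
                count = 0
        })
        if err != nil {
                return 0, err
        }

        return count, nil
}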
3027

3028
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3029
// read only transaction for a consistent view of the graph.
3030
type sqlNodeTraverser struct {
3031
        db    SQLQueries
3032
        chain chainhash.Hash
3033
}
3034

3035
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3036
// NodeTraverser interface.
3037
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3038

3039
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3040
func newSQLNodeTraverser(db SQLQueries,
3041
        chain chainhash.Hash) *sqlNodeTraverser {
×
3042

×
3043
        return &sqlNodeTraverser{
×
3044
                db:    db,
×
3045
                chain: chain,
×
3046
        }
×
3047
}
×
3048

3049
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3050
// node.
3051
//
3052
// NOTE: Part of the NodeTraverser interface.
3053
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3054
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3055

×
3056
        ctx := context.TODO()
×
3057

×
3058
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3059
}
×
3060

3061
// FetchNodeFeatures returns the features of the given node. If the node is
3062
// unknown, assume no additional features are supported.
3063
//
3064
// NOTE: Part of the NodeTraverser interface.
3065
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3066
        *lnwire.FeatureVector, error) {
×
3067

×
3068
        ctx := context.TODO()
×
3069

×
3070
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3071
}
×
3072

3073
// forEachNodeDirectedChannel iterates through all channels of a given
3074
// node, executing the passed callback on the directed edge representing the
3075
// channel and its incoming policy. If the node is not found, no error is
3076
// returned.
3077
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3078
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3079

×
3080
        toNodeCallback := func() route.Vertex {
×
3081
                return nodePub
×
3082
        }
×
3083

3084
        dbID, err := db.GetNodeIDByPubKey(
×
3085
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3086
                        Version: int16(lnwire.GossipVersion1),
×
3087
                        PubKey:  nodePub[:],
×
3088
                },
×
3089
        )
×
3090
        if errors.Is(err, sql.ErrNoRows) {
×
3091
                return nil
×
3092
        } else if err != nil {
×
3093
                return fmt.Errorf("unable to fetch node: %w", err)
×
3094
        }
×
3095

3096
        rows, err := db.ListChannelsByNodeID(
×
3097
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3098
                        Version: int16(lnwire.GossipVersion1),
×
3099
                        NodeID1: dbID,
×
3100
                },
×
3101
        )
×
3102
        if err != nil {
×
3103
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3104
        }
×
3105

3106
        // Exit early if there are no channels for this node so we don't
3107
        // do the unnecessary feature fetching.
3108
        if len(rows) == 0 {
×
3109
                return nil
×
3110
        }
×
3111

3112
        features, err := getNodeFeatures(ctx, db, dbID)
×
3113
        if err != nil {
×
3114
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3115
        }
×
3116

3117
        for _, row := range rows {
×
3118
                node1, node2, err := buildNodeVertices(
×
3119
                        row.Node1Pubkey, row.Node2Pubkey,
×
3120
                )
×
3121
                if err != nil {
×
3122
                        return fmt.Errorf("unable to build node vertices: %w",
×
3123
                                err)
×
3124
                }
×
3125

3126
                edge := buildCacheableChannelInfo(
×
3127
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3128
                        node1, node2,
×
3129
                )
×
3130

×
3131
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3132
                if err != nil {
×
3133
                        return err
×
3134
                }
×
3135

3136
                p1, p2, err := buildCachedChanPolicies(
×
3137
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3138
                )
×
3139
                if err != nil {
×
3140
                        return err
×
3141
                }
×
3142

3143
                // Determine the outgoing and incoming policy for this
3144
                // channel and node combo.
3145
                outPolicy, inPolicy := p1, p2
×
3146
                if p1 != nil && node2 == nodePub {
×
3147
                        outPolicy, inPolicy = p2, p1
×
3148
                } else if p2 != nil && node1 != nodePub {
×
3149
                        outPolicy, inPolicy = p2, p1
×
3150
                }
×
3151

3152
                var cachedInPolicy *models.CachedEdgePolicy
×
3153
                if inPolicy != nil {
×
3154
                        cachedInPolicy = inPolicy
×
3155
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3156
                        cachedInPolicy.ToNodeFeatures = features
×
3157
                }
×
3158

3159
                directedChannel := &DirectedChannel{
×
3160
                        ChannelID:    edge.ChannelID,
×
3161
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3162
                        OtherNode:    edge.NodeKey2Bytes,
×
3163
                        Capacity:     edge.Capacity,
×
3164
                        OutPolicySet: outPolicy != nil,
×
3165
                        InPolicy:     cachedInPolicy,
×
3166
                }
×
3167
                if outPolicy != nil {
×
3168
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3169
                                directedChannel.InboundFee = fee
×
3170
                        })
×
3171
                }
3172

3173
                if nodePub == edge.NodeKey2Bytes {
×
3174
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3175
                }
×
3176

3177
                if err := cb(directedChannel); err != nil {
×
3178
                        return err
×
3179
                }
×
3180
        }
3181

3182
        return nil
×
3183
}
3184

3185
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3186
// and executes the provided callback for each node. It does so via pagination
3187
// along with batch loading of the node feature bits.
3188
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3189
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3190
                features *lnwire.FeatureVector) error) error {
×
3191

×
3192
        handleNode := func(_ context.Context,
×
3193
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3194
                featureBits map[int64][]int) error {
×
3195

×
3196
                fv := lnwire.EmptyFeatureVector()
×
3197
                if features, exists := featureBits[dbNode.ID]; exists {
×
3198
                        for _, bit := range features {
×
3199
                                fv.Set(lnwire.FeatureBit(bit))
×
3200
                        }
×
3201
                }
3202

3203
                var pub route.Vertex
×
3204
                copy(pub[:], dbNode.PubKey)
×
3205

×
3206
                return processNode(dbNode.ID, pub, fv)
×
3207
        }
3208

3209
        queryFunc := func(ctx context.Context, lastID int64,
×
3210
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3211

×
3212
                return db.ListNodeIDsAndPubKeys(
×
3213
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3214
                                Version: int16(lnwire.GossipVersion1),
×
3215
                                ID:      lastID,
×
3216
                                Limit:   limit,
×
3217
                        },
×
3218
                )
×
3219
        }
×
3220

3221
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3222
                return row.ID
×
3223
        }
×
3224

3225
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3226
                return node.ID, nil
×
3227
        }
×
3228

3229
        batchQueryFunc := func(ctx context.Context,
×
3230
                nodeIDs []int64) (map[int64][]int, error) {
×
3231

×
3232
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3233
        }
×
3234

3235
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3236
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3237
                batchQueryFunc, handleNode,
×
3238
        )
×
3239
}
3240

3241
// forEachNodeChannel iterates through all channels of a node, executing
3242
// the passed callback on each. The call-back is provided with the channel's
3243
// edge information, the outgoing policy and the incoming policy for the
3244
// channel and node combo.
3245
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3246
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3247
                *models.ChannelEdgePolicy,
3248
                *models.ChannelEdgePolicy) error) error {
×
3249

×
3250
        // Get all the V1 channels for this node.
×
3251
        rows, err := db.ListChannelsByNodeID(
×
3252
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3253
                        Version: int16(lnwire.GossipVersion1),
×
3254
                        NodeID1: id,
×
3255
                },
×
3256
        )
×
3257
        if err != nil {
×
3258
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3259
        }
×
3260

3261
        // Collect all the channel and policy IDs.
3262
        var (
×
3263
                chanIDs   = make([]int64, 0, len(rows))
×
3264
                policyIDs = make([]int64, 0, 2*len(rows))
×
3265
        )
×
3266
        for _, row := range rows {
×
3267
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3268

×
3269
                if row.Policy1ID.Valid {
×
3270
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3271
                }
×
3272
                if row.Policy2ID.Valid {
×
3273
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3274
                }
×
3275
        }
3276

3277
        batchData, err := batchLoadChannelData(
×
3278
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3279
        )
×
3280
        if err != nil {
×
3281
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3282
        }
×
3283

3284
        // Call the call-back for each channel and its known policies.
3285
        for _, row := range rows {
×
3286
                node1, node2, err := buildNodeVertices(
×
3287
                        row.Node1Pubkey, row.Node2Pubkey,
×
3288
                )
×
3289
                if err != nil {
×
3290
                        return fmt.Errorf("unable to build node vertices: %w",
×
3291
                                err)
×
3292
                }
×
3293

3294
                edge, err := buildEdgeInfoWithBatchData(
×
3295
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3296
                        batchData,
×
3297
                )
×
3298
                if err != nil {
×
3299
                        return fmt.Errorf("unable to build channel info: %w",
×
3300
                                err)
×
3301
                }
×
3302

3303
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3304
                if err != nil {
×
3305
                        return fmt.Errorf("unable to extract channel "+
×
3306
                                "policies: %w", err)
×
3307
                }
×
3308

3309
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3310
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3311
                )
×
3312
                if err != nil {
×
3313
                        return fmt.Errorf("unable to build channel "+
×
3314
                                "policies: %w", err)
×
3315
                }
×
3316

3317
                // Determine the outgoing and incoming policy for this
3318
                // channel and node combo.
3319
                p1ToNode := row.GraphChannel.NodeID2
×
3320
                p2ToNode := row.GraphChannel.NodeID1
×
3321
                outPolicy, inPolicy := p1, p2
×
3322
                if (p1 != nil && p1ToNode == id) ||
×
3323
                        (p2 != nil && p2ToNode != id) {
×
3324

×
3325
                        outPolicy, inPolicy = p2, p1
×
3326
                }
×
3327

3328
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3329
                        return err
×
3330
                }
×
3331
        }
3332

3333
        return nil
×
3334
}
3335

3336
// updateChanEdgePolicy upserts the channel policy info we have stored for
3337
// a channel we already know of.
3338
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3339
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3340
        error) {
×
3341

×
3342
        var (
×
3343
                node1Pub, node2Pub route.Vertex
×
3344
                isNode1            bool
×
3345
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3346
        )
×
3347

×
3348
        // Check that this edge policy refers to a channel that we already
×
3349
        // know of. We do this explicitly so that we can return the appropriate
×
3350
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3351
        // abort the transaction which would abort the entire batch.
×
3352
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3353
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3354
                        Scid:    chanIDB,
×
3355
                        Version: int16(lnwire.GossipVersion1),
×
3356
                },
×
3357
        )
×
3358
        if errors.Is(err, sql.ErrNoRows) {
×
3359
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3360
        } else if err != nil {
×
3361
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3362
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3363
        }
×
3364

3365
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3366
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3367

×
3368
        // Figure out which node this edge is from.
×
3369
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3370
        nodeID := dbChan.NodeID1
×
3371
        if !isNode1 {
×
3372
                nodeID = dbChan.NodeID2
×
3373
        }
×
3374

3375
        var (
×
3376
                inboundBase sql.NullInt64
×
3377
                inboundRate sql.NullInt64
×
3378
        )
×
3379
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3380
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3381
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3382
        })
×
3383

3384
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3385
                Version:     int16(lnwire.GossipVersion1),
×
3386
                ChannelID:   dbChan.ID,
×
3387
                NodeID:      nodeID,
×
3388
                Timelock:    int32(edge.TimeLockDelta),
×
3389
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3390
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3391
                MinHtlcMsat: int64(edge.MinHTLC),
×
3392
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3393
                Disabled: sql.NullBool{
×
3394
                        Valid: true,
×
3395
                        Bool:  edge.IsDisabled(),
×
3396
                },
×
3397
                MaxHtlcMsat: sql.NullInt64{
×
3398
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3399
                        Int64: int64(edge.MaxHTLC),
×
3400
                },
×
3401
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3402
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3403
                InboundBaseFeeMsat:      inboundBase,
×
3404
                InboundFeeRateMilliMsat: inboundRate,
×
3405
                Signature:               edge.SigBytes,
×
3406
        })
×
3407
        if err != nil {
×
3408
                return node1Pub, node2Pub, isNode1,
×
3409
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3410
        }
×
3411

3412
        // Convert the flat extra opaque data into a map of TLV types to
3413
        // values.
3414
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3415
        if err != nil {
×
3416
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3417
                        "marshal extra opaque data: %w", err)
×
3418
        }
×
3419

3420
        // Update the channel policy's extra signed fields.
3421
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3422
        if err != nil {
×
3423
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3424
                        "policy extra TLVs: %w", err)
×
3425
        }
×
3426

3427
        return node1Pub, node2Pub, isNode1, nil
×
3428
}
3429

3430
// getNodeByPubKey attempts to look up a target node by its public key.
3431
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3432
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3433

×
3434
        dbNode, err := db.GetNodeByPubKey(
×
3435
                ctx, sqlc.GetNodeByPubKeyParams{
×
3436
                        Version: int16(lnwire.GossipVersion1),
×
3437
                        PubKey:  pubKey[:],
×
3438
                },
×
3439
        )
×
3440
        if errors.Is(err, sql.ErrNoRows) {
×
3441
                return 0, nil, ErrGraphNodeNotFound
×
3442
        } else if err != nil {
×
3443
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3444
        }
×
3445

3446
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3447
        if err != nil {
×
3448
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3449
        }
×
3450

3451
        return dbNode.ID, node, nil
×
3452
}
3453

3454
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3455
// provided parameters.
3456
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3457
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3458

×
3459
        return &models.CachedEdgeInfo{
×
3460
                ChannelID:     byteOrder.Uint64(scid),
×
3461
                NodeKey1Bytes: node1Pub,
×
3462
                NodeKey2Bytes: node2Pub,
×
3463
                Capacity:      btcutil.Amount(capacity),
×
3464
        }
×
3465
}
×
3466

3467
// buildNode constructs a Node instance from the given database node
3468
// record. The node's features, addresses and extra signed fields are also
3469
// fetched from the database and set on the node.
3470
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3471
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3472

×
3473
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3474
        if err != nil {
×
3475
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3476
                        err)
×
3477
        }
×
3478

3479
        return buildNodeWithBatchData(dbNode, data)
×
3480
}
3481

3482
// buildNodeWithBatchData builds a models.Node instance
3483
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3484
// features/addresses/extra fields, then the corresponding fields are expected
3485
// to be present in the batchNodeData.
3486
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3487
        batchData *batchNodeData) (*models.Node, error) {
×
3488

×
3489
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3490
                return nil, fmt.Errorf("unsupported node version: %d",
×
3491
                        dbNode.Version)
×
3492
        }
×
3493

3494
        var pub [33]byte
×
3495
        copy(pub[:], dbNode.PubKey)
×
3496

×
3497
        node := models.NewV1ShellNode(pub)
×
3498

×
3499
        if len(dbNode.Signature) == 0 {
×
3500
                return node, nil
×
3501
        }
×
3502

3503
        node.AuthSigBytes = dbNode.Signature
×
3504

×
3505
        if dbNode.Alias.Valid {
×
3506
                node.Alias = fn.Some(dbNode.Alias.String)
×
3507
        }
×
3508
        if dbNode.LastUpdate.Valid {
×
3509
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3510
        }
×
3511

3512
        var err error
×
3513
        if dbNode.Color.Valid {
×
3514
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3515
                if err != nil {
×
3516
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3517
                                err)
×
3518
                }
×
3519

3520
                node.Color = fn.Some(nodeColor)
×
3521
        }
3522

3523
        // Use preloaded features.
3524
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3525
                fv := lnwire.EmptyFeatureVector()
×
3526
                for _, bit := range features {
×
3527
                        fv.Set(lnwire.FeatureBit(bit))
×
3528
                }
×
3529
                node.Features = fv
×
3530
        }
3531

3532
        // Use preloaded addresses.
3533
        addresses, exists := batchData.addresses[dbNode.ID]
×
3534
        if exists && len(addresses) > 0 {
×
3535
                node.Addresses, err = buildNodeAddresses(addresses)
×
3536
                if err != nil {
×
3537
                        return nil, fmt.Errorf("unable to build addresses "+
×
3538
                                "for node(%d): %w", dbNode.ID, err)
×
3539
                }
×
3540
        }
3541

3542
        // Use preloaded extra fields.
3543
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3544
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3545
                if err != nil {
×
3546
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3547
                                "signed fields: %w", err)
×
3548
                }
×
3549
                if len(recs) != 0 {
×
3550
                        node.ExtraOpaqueData = recs
×
3551
                }
×
3552
        }
3553

3554
        return node, nil
×
3555
}
3556

3557
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3558
// with the preloaded data, and executes the provided callback for each node.
3559
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3560
        db SQLQueries, nodes []sqlc.GraphNode,
3561
        cb func(dbID int64, node *models.Node) error) error {
×
3562

×
3563
        // Extract node IDs for batch loading.
×
3564
        nodeIDs := make([]int64, len(nodes))
×
3565
        for i, node := range nodes {
×
3566
                nodeIDs[i] = node.ID
×
3567
        }
×
3568

3569
        // Batch load all related data for this page.
3570
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3571
        if err != nil {
×
3572
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3573
        }
×
3574

3575
        for _, dbNode := range nodes {
×
3576
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3577
                if err != nil {
×
3578
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3579
                                dbNode.ID, err)
×
3580
                }
×
3581

3582
                if err := cb(dbNode.ID, node); err != nil {
×
3583
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3584
                                dbNode.ID, err)
×
3585
                }
×
3586
        }
3587

3588
        return nil
×
3589
}
3590

3591
// getNodeFeatures fetches the feature bits and constructs the feature vector
3592
// for a node with the given DB ID.
3593
func getNodeFeatures(ctx context.Context, db SQLQueries,
3594
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3595

×
3596
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3597
        if err != nil {
×
3598
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3599
                        nodeID, err)
×
3600
        }
×
3601

3602
        features := lnwire.EmptyFeatureVector()
×
3603
        for _, feature := range rows {
×
3604
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3605
        }
×
3606

3607
        return features, nil
×
3608
}

// upsertNodeAncillaryData updates the node's features, addresses, and extra
// signed fields. This is common logic shared by upsertNode and
// upsertSourceNode.
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
        nodeID int64, node *models.Node) error {

        // Update the node's features.
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nil
}

// populateNodeParams populates the common node parameters from a models.Node.
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
func populateNodeParams(node *models.Node,
        setParams func(lastUpdate sql.NullInt64, alias,
                colorStr sql.NullString, signature []byte)) error {

        if !node.HaveAnnouncement() {
                return nil
        }

        switch node.Version {
        case lnwire.GossipVersion1:
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
                var alias, colorStr sql.NullString

                node.Color.WhenSome(func(rgba color.RGBA) {
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
                })
                node.Alias.WhenSome(func(s string) {
                        alias = sqldb.SQLStrValid(s)
                })

                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)

        case lnwire.GossipVersion2:
                // No-op for now.

        default:
                return fmt.Errorf("unknown gossip version: %d", node.Version)
        }

        return nil
}
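
// NOTE (illustrative sketch, not part of this store): populateNodeParams
// pushes the shared announcement fields through a setter callback so the same
// logic can fill both sqlc.UpsertNodeParams and sqlc.UpsertSourceNodeParams
// without an intermediate struct. With hypothetical parameter types, the
// pattern reduces to:
//
//    type strictParams struct{ LastUpdate int64 }
//    type lenientParams struct{ LastUpdate int64 }
//
//    // populate computes the shared values once and lets each caller assign
//    // them to its own generated parameter type.
//    func populate(ts int64, set func(lastUpdate int64)) { set(ts) }
//
//    func buildStrict(ts int64) strictParams {
//        var p strictParams
//        populate(ts, func(lu int64) { p.LastUpdate = lu })
//        return p
//    }
//
//    func buildLenient(ts int64) lenientParams {
//        var p lenientParams
//        populate(ts, func(lu int64) { p.LastUpdate = lu })
//        return p
//    }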

// buildNodeUpsertParams builds the parameters for upserting a node using the
// strict UpsertNode query (requires timestamp to be increasing).
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
        params := sqlc.UpsertNodeParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  node.PubKeyBytes[:],
        }

        err := populateNodeParams(
                node, func(lastUpdate sql.NullInt64, alias,
                        colorStr sql.NullString,
                        signature []byte) {

                        params.LastUpdate = lastUpdate
                        params.Alias = alias
                        params.Color = colorStr
                        params.Signature = signature
                })

        return params, err
}

// buildSourceNodeUpsertParams builds the parameters for upserting the source
// node using the lenient UpsertSourceNode query (allows same timestamp).
func buildSourceNodeUpsertParams(node *models.Node) (
        sqlc.UpsertSourceNodeParams, error) {

        params := sqlc.UpsertSourceNodeParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  node.PubKeyBytes[:],
        }

        err := populateNodeParams(
                node, func(lastUpdate sql.NullInt64, alias,
                        colorStr sql.NullString, signature []byte) {

                        params.LastUpdate = lastUpdate
                        params.Alias = alias
                        params.Color = colorStr
                        params.Signature = signature
                },
        )

        return params, err
}

// upsertSourceNode upserts the source node record into the database using a
// less strict upsert that allows updates even when the timestamp hasn't
// changed. This is necessary to handle concurrent updates to our own node
// during startup and runtime. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertSourceNode(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params, err := buildSourceNodeUpsertParams(node)
        if err != nil {
                return 0, err
        }

        nodeID, err := db.UpsertSourceNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting source node(%x): %w",
                        node.PubKeyBytes, err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveAnnouncement() {
                return nodeID, nil
        }

        // Update the ancillary node data (features, addresses, extra fields).
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
        if err != nil {
                return 0, err
        }

        return nodeID, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params, err := buildNodeUpsertParams(node)
        if err != nil {
                return 0, err
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveAnnouncement() {
                return nodeID, nil
        }

        // Update the ancillary node data (features, addresses, extra fields).
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
        if err != nil {
                return 0, err
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it
        // from the in-memory map. For any existing feature that does not
        // exist in the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to
        // be added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}
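
// NOTE (illustrative sketch, not part of this store): upsertNodeFeatures is a
// diff-and-reconcile update: bits present both in the DB and in the new
// vector are left untouched, stale DB rows are deleted, and only genuinely
// new bits are inserted. The reconciliation step on its own, over plain
// integer sets, looks like this:
//
//    // diff returns the values to delete from storage and the values to
//    // insert, given the currently stored set and the desired set.
//    func diff(existing, desired map[int32]struct{}) (toDelete,
//        toInsert []int32) {
//
//        for bit := range existing {
//            if _, ok := desired[bit]; !ok {
//                toDelete = append(toDelete, bit)
//            }
//        }
//        for bit := range desired {
//            if _, ok := existing[bit]; !ok {
//                toInsert = append(toInsert, bit)
//            }
//        }
//
//        return toDelete, toInsert
//    }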
3850

3851
// fetchNodeFeatures fetches the features for a node with the given public key.
3852
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3853
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3854

×
3855
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3856
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3857
                        PubKey:  nodePub[:],
×
3858
                        Version: int16(lnwire.GossipVersion1),
×
3859
                },
×
3860
        )
×
3861
        if err != nil {
×
3862
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3863
                        nodePub, err)
×
3864
        }
×
3865

3866
        features := lnwire.EmptyFeatureVector()
×
3867
        for _, bit := range rows {
×
3868
                features.Set(lnwire.FeatureBit(bit))
×
3869
        }
×
3870

3871
        return features, nil
×
3872
}
3873

3874
// dbAddressType is an enum type that represents the different address types
3875
// that we store in the node_addresses table. The address type determines how
3876
// the address is to be serialised/deserialised.
3877
type dbAddressType uint8
3878

3879
const (
3880
        addressTypeIPv4   dbAddressType = 1
3881
        addressTypeIPv6   dbAddressType = 2
3882
        addressTypeTorV2  dbAddressType = 3
3883
        addressTypeTorV3  dbAddressType = 4
3884
        addressTypeDNS    dbAddressType = 5
3885
        addressTypeOpaque dbAddressType = math.MaxInt8
3886
)
3887

3888
// collectAddressRecords collects the addresses from the provided
3889
// net.Addr slice and returns a map of dbAddressType to a slice of address
3890
// strings.
3891
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3892
        error) {
×
3893

×
3894
        // Copy the node's latest set of addresses.
×
3895
        newAddresses := map[dbAddressType][]string{
×
3896
                addressTypeIPv4:   {},
×
3897
                addressTypeIPv6:   {},
×
3898
                addressTypeTorV2:  {},
×
3899
                addressTypeTorV3:  {},
×
3900
                addressTypeDNS:    {},
×
3901
                addressTypeOpaque: {},
×
3902
        }
×
3903
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3904
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3905
        }
×
3906

3907
        for _, address := range addresses {
×
3908
                switch addr := address.(type) {
×
3909
                case *net.TCPAddr:
×
3910
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3911
                                addAddr(addressTypeIPv4, addr)
×
3912
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3913
                                addAddr(addressTypeIPv6, addr)
×
3914
                        } else {
×
3915
                                return nil, fmt.Errorf("unhandled IP "+
×
3916
                                        "address: %v", addr)
×
3917
                        }
×
3918

3919
                case *tor.OnionAddr:
×
3920
                        switch len(addr.OnionService) {
×
3921
                        case tor.V2Len:
×
3922
                                addAddr(addressTypeTorV2, addr)
×
3923
                        case tor.V3Len:
×
3924
                                addAddr(addressTypeTorV3, addr)
×
3925
                        default:
×
3926
                                return nil, fmt.Errorf("invalid length for " +
×
3927
                                        "a tor address")
×
3928
                        }
3929

3930
                case *lnwire.DNSAddress:
×
3931
                        addAddr(addressTypeDNS, addr)
×
3932

3933
                case *lnwire.OpaqueAddrs:
×
3934
                        addAddr(addressTypeOpaque, addr)
×
3935

3936
                default:
×
3937
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3938
                                addr)
×
3939
                }
3940
        }
3941

3942
        return newAddresses, nil
×
3943
}
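
// NOTE (illustrative sketch, not part of this store): collectAddressRecords
// classifies addresses purely by their concrete net.Addr type and, for TCP
// addresses, by whether the IP fits in four bytes. The TCP branch of that
// dispatch, reduced to a standalone helper, would be:
//
//    func classifyTCP(addr *net.TCPAddr) (string, error) {
//        switch {
//        case addr.IP.To4() != nil:
//            return "ipv4", nil
//        case addr.IP.To16() != nil:
//            return "ipv6", nil
//        default:
//            return "", fmt.Errorf("unhandled IP address: %v", addr)
//        }
//    }
//
//    // classifyTCP(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 9735})
//    // returns "ipv4", while a parsed "2001:db8::1" returns "ipv6".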
3944

3945
// upsertNodeAddresses updates the node's addresses in the database. This
3946
// includes deleting any existing addresses and inserting the new set of
3947
// addresses. The deletion is necessary since the ordering of the addresses may
3948
// change, and we need to ensure that the database reflects the latest set of
3949
// addresses so that at the time of reconstructing the node announcement, the
3950
// order is preserved and the signature over the message remains valid.
3951
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3952
        addresses []net.Addr) error {
×
3953

×
3954
        // Delete any existing addresses for the node. This is required since
×
3955
        // even if the new set of addresses is the same, the ordering may have
×
3956
        // changed for a given address type.
×
3957
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3958
        if err != nil {
×
3959
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3960
                        nodeID, err)
×
3961
        }
×
3962

3963
        newAddresses, err := collectAddressRecords(addresses)
×
3964
        if err != nil {
×
3965
                return err
×
3966
        }
×
3967

3968
        // Any remaining entries in newAddresses are new addresses that need to
3969
        // be added to the database for the first time.
3970
        for addrType, addrList := range newAddresses {
×
3971
                for position, addr := range addrList {
×
3972
                        err := db.UpsertNodeAddress(
×
3973
                                ctx, sqlc.UpsertNodeAddressParams{
×
3974
                                        NodeID:   nodeID,
×
3975
                                        Type:     int16(addrType),
×
3976
                                        Address:  addr,
×
3977
                                        Position: int32(position),
×
3978
                                },
×
3979
                        )
×
3980
                        if err != nil {
×
3981
                                return fmt.Errorf("unable to insert "+
×
3982
                                        "node(%d) address(%v): %w", nodeID,
×
3983
                                        addr, err)
×
3984
                        }
×
3985
                }
3986
        }
3987

3988
        return nil
×
3989
}
3990

3991
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3992
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3993
        error) {
×
3994

×
3995
        // GetNodeAddresses ensures that the addresses for a given type are
×
3996
        // returned in the same order as they were inserted.
×
3997
        rows, err := db.GetNodeAddresses(ctx, id)
×
3998
        if err != nil {
×
3999
                return nil, err
×
4000
        }
×
4001

4002
        addresses := make([]net.Addr, 0, len(rows))
×
4003
        for _, row := range rows {
×
4004
                address := row.Address
×
4005

×
4006
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4007
                if err != nil {
×
4008
                        return nil, fmt.Errorf("unable to parse address "+
×
4009
                                "for node(%d): %v: %w", id, address, err)
×
4010
                }
×
4011

4012
                addresses = append(addresses, addr)
×
4013
        }
4014

4015
        // If we have no addresses, then we'll return nil instead of an
4016
        // empty slice.
4017
        if len(addresses) == 0 {
×
4018
                addresses = nil
×
4019
        }
×
4020

4021
        return addresses, nil
×
4022
}
4023

4024
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4025
// database. This includes updating any existing types, inserting any new types,
4026
// and deleting any types that are no longer present.
4027
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4028
        nodeID int64, extraFields map[uint64][]byte) error {
×
4029

×
4030
        // Get any existing extra signed fields for the node.
×
4031
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4032
        if err != nil {
×
4033
                return err
×
4034
        }
×
4035

4036
        // Make a lookup map of the existing field types so that we can use it
4037
        // to keep track of any fields we should delete.
4038
        m := make(map[uint64]bool)
×
4039
        for _, field := range existingFields {
×
4040
                m[uint64(field.Type)] = true
×
4041
        }
×
4042

4043
        // For all the new fields, we'll upsert them and remove them from the
4044
        // map of existing fields.
4045
        for tlvType, value := range extraFields {
×
4046
                err = db.UpsertNodeExtraType(
×
4047
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4048
                                NodeID: nodeID,
×
4049
                                Type:   int64(tlvType),
×
4050
                                Value:  value,
×
4051
                        },
×
4052
                )
×
4053
                if err != nil {
×
4054
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4055
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4056
                }
×
4057

4058
                // Remove the field from the map of existing fields if it was
4059
                // present.
4060
                delete(m, tlvType)
×
4061
        }
4062

4063
        // For all the fields that are left in the map of existing fields, we'll
4064
        // delete them as they are no longer present in the new set of fields.
4065
        for tlvType := range m {
×
4066
                err = db.DeleteExtraNodeType(
×
4067
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4068
                                NodeID: nodeID,
×
4069
                                Type:   int64(tlvType),
×
4070
                        },
×
4071
                )
×
4072
                if err != nil {
×
4073
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4074
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4075
                }
×
4076
        }
4077

4078
        return nil
×
4079
}
4080

4081
// srcNodeInfo holds the information about the source node of the graph.
4082
type srcNodeInfo struct {
4083
        // id is the DB level ID of the source node entry in the "nodes" table.
4084
        id int64
4085

4086
        // pub is the public key of the source node.
4087
        pub route.Vertex
4088
}
4089

4090
// getSourceNode returns the DB node ID and pub key of the source node for the
4091
// specified protocol version.
4092
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4093
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4094

×
4095
        s.srcNodeMu.Lock()
×
4096
        defer s.srcNodeMu.Unlock()
×
4097

×
4098
        // If we already have the source node ID and pub key cached, then
×
4099
        // return them.
×
4100
        if info, ok := s.srcNodes[version]; ok {
×
4101
                return info.id, info.pub, nil
×
4102
        }
×
4103

4104
        var pubKey route.Vertex
×
4105

×
4106
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4107
        if err != nil {
×
4108
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4109
                        err)
×
4110
        }
×
4111

4112
        if len(nodes) == 0 {
×
4113
                return 0, pubKey, ErrSourceNodeNotSet
×
4114
        } else if len(nodes) > 1 {
×
4115
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4116
                        "protocol %s found", version)
×
4117
        }
×
4118

4119
        copy(pubKey[:], nodes[0].PubKey)
×
4120

×
4121
        s.srcNodes[version] = &srcNodeInfo{
×
4122
                id:  nodes[0].NodeID,
×
4123
                pub: pubKey,
×
4124
        }
×
4125

×
4126
        return nodes[0].NodeID, pubKey, nil
×
4127
}
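
// NOTE (illustrative sketch, not part of this store): getSourceNode is a
// lazily filled, mutex-guarded cache keyed by gossip version: the DB is only
// consulted on a miss and the result is memoised for later calls. The generic
// shape of that cache, with a hypothetical fetch function, is:
//
//    type lazyCache struct {
//        mu      sync.Mutex
//        entries map[int16]int64 // must be initialised by the caller.
//    }
//
//    // get returns the cached value for version, calling fetch at most once
//    // per version for the lifetime of the cache.
//    func (c *lazyCache) get(version int16,
//        fetch func(int16) (int64, error)) (int64, error) {
//
//        c.mu.Lock()
//        defer c.mu.Unlock()
//
//        if v, ok := c.entries[version]; ok {
//            return v, nil
//        }
//
//        v, err := fetch(version)
//        if err != nil {
//            return 0, err
//        }
//        c.entries[version] = v
//
//        return v, nil
//    }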
4128

4129
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream.
4130
// This then produces a map from TLV type to value. If the input is not a
4131
// valid TLV stream, then an error is returned.
4132
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4133
        r := bytes.NewReader(data)
×
4134

×
4135
        tlvStream, err := tlv.NewStream()
×
4136
        if err != nil {
×
4137
                return nil, err
×
4138
        }
×
4139

4140
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4141
        // pass it into the P2P decoding variant.
4142
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4143
        if err != nil {
×
4144
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4145
        }
×
4146
        if len(parsedTypes) == 0 {
×
4147
                return nil, nil
×
4148
        }
×
4149

4150
        records := make(map[uint64][]byte)
×
4151
        for k, v := range parsedTypes {
×
4152
                records[uint64(k)] = v
×
4153
        }
×
4154

4155
        return records, nil
×
4156
}
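
// NOTE (illustrative sketch, not part of this store): marshalExtraOpaqueData
// turns an opaque TLV stream into a type => value map via the tlv package's
// P2P decoder. The toy parser below handles a simplified stream in which both
// the type and the length are single bytes; the real BOLT TLV encoding uses
// BigSize varints for both, which is why the code above delegates to
// tlv.NewStream rather than parsing by hand:
//
//    func parseSimpleTLV(data []byte) (map[uint64][]byte, error) {
//        records := make(map[uint64][]byte)
//        for len(data) > 0 {
//            if len(data) < 2 {
//                return nil, fmt.Errorf("truncated record header")
//            }
//            typ, length := uint64(data[0]), int(data[1])
//            if len(data) < 2+length {
//                return nil, fmt.Errorf("truncated record value")
//            }
//            records[typ] = data[2 : 2+length]
//            data = data[2+length:]
//        }
//
//        return records, nil
//    }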
4157

4158
// insertChannel inserts a new channel record into the database.
4159
func insertChannel(ctx context.Context, db SQLQueries,
4160
        edge *models.ChannelEdgeInfo) error {
×
4161

×
4162
        // Make sure that at least a "shell" entry for each node is present in
×
4163
        // the nodes table.
×
4164
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4165
        if err != nil {
×
4166
                return fmt.Errorf("unable to create shell node: %w", err)
×
4167
        }
×
4168

4169
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4170
        if err != nil {
×
4171
                return fmt.Errorf("unable to create shell node: %w", err)
×
4172
        }
×
4173

4174
        var capacity sql.NullInt64
×
4175
        if edge.Capacity != 0 {
×
4176
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4177
        }
×
4178

4179
        createParams := sqlc.CreateChannelParams{
×
4180
                Version:     int16(lnwire.GossipVersion1),
×
4181
                Scid:        channelIDToBytes(edge.ChannelID),
×
4182
                NodeID1:     node1DBID,
×
4183
                NodeID2:     node2DBID,
×
4184
                Outpoint:    edge.ChannelPoint.String(),
×
4185
                Capacity:    capacity,
×
4186
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4187
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4188
        }
×
4189

×
4190
        if edge.AuthProof != nil {
×
4191
                proof := edge.AuthProof
×
4192

×
4193
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4194
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4195
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4196
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4197
        }
×
4198

4199
        // Insert the new channel record.
4200
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4201
        if err != nil {
×
4202
                return err
×
4203
        }
×
4204

4205
        // Insert any channel features.
4206
        for feature := range edge.Features.Features() {
×
4207
                err = db.InsertChannelFeature(
×
4208
                        ctx, sqlc.InsertChannelFeatureParams{
×
4209
                                ChannelID:  dbChanID,
×
4210
                                FeatureBit: int32(feature),
×
4211
                        },
×
4212
                )
×
4213
                if err != nil {
×
4214
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4215
                                "feature(%v): %w", dbChanID, feature, err)
×
4216
                }
×
4217
        }
4218

4219
        // Finally, insert any extra TLV fields in the channel announcement.
4220
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4221
        if err != nil {
×
4222
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4223
                        err)
×
4224
        }
×
4225

4226
        for tlvType, value := range extra {
×
4227
                err := db.UpsertChannelExtraType(
×
4228
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4229
                                ChannelID: dbChanID,
×
4230
                                Type:      int64(tlvType),
×
4231
                                Value:     value,
×
4232
                        },
×
4233
                )
×
4234
                if err != nil {
×
4235
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4236
                                "extra signed field(%v): %w", edge.ChannelID,
×
4237
                                tlvType, err)
×
4238
                }
×
4239
        }
4240

4241
        return nil
×
4242
}
4243

4244
// maybeCreateShellNode checks if a shell node entry exists for the
4245
// given public key. If it does not exist, then a new shell node entry is
4246
// created. The ID of the node is returned. A shell node only has a protocol
4247
// version and public key persisted.
4248
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4249
        pubKey route.Vertex) (int64, error) {
×
4250

×
4251
        dbNode, err := db.GetNodeByPubKey(
×
4252
                ctx, sqlc.GetNodeByPubKeyParams{
×
4253
                        PubKey:  pubKey[:],
×
4254
                        Version: int16(lnwire.GossipVersion1),
×
4255
                },
×
4256
        )
×
4257
        // The node exists. Return the ID.
×
4258
        if err == nil {
×
4259
                return dbNode.ID, nil
×
4260
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4261
                return 0, err
×
4262
        }
×
4263

4264
        // Otherwise, the node does not exist, so we create a shell entry for
4265
        // it.
4266
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4267
                Version: int16(lnwire.GossipVersion1),
×
4268
                PubKey:  pubKey[:],
×
4269
        })
×
4270
        if err != nil {
×
4271
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4272
        }
×
4273

4274
        return id, nil
×
4275
}
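
// NOTE (illustrative sketch, not part of this store): maybeCreateShellNode is
// a get-or-create helper: look the row up first and fall back to an insert
// only when the lookup reports sql.ErrNoRows. Against a hypothetical table
// queried with database/sql directly (Postgres-style RETURNING), the same
// flow would be:
//
//    func getOrCreate(ctx context.Context, db *sql.DB,
//        pubKey []byte) (int64, error) {
//
//        var id int64
//        err := db.QueryRowContext(ctx,
//            "SELECT id FROM nodes WHERE pub_key = $1", pubKey,
//        ).Scan(&id)
//        if err == nil {
//            return id, nil
//        }
//        if !errors.Is(err, sql.ErrNoRows) {
//            return 0, err
//        }
//
//        err = db.QueryRowContext(ctx,
//            "INSERT INTO nodes (pub_key) VALUES ($1) RETURNING id",
//            pubKey,
//        ).Scan(&id)
//
//        return id, err
//    }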
4276

4277
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4278
// the database. This includes deleting any existing types and then inserting
4279
// the new types.
4280
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4281
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4282

×
4283
        // Delete all existing extra signed fields for the channel policy.
×
4284
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4285
        if err != nil {
×
4286
                return fmt.Errorf("unable to delete "+
×
4287
                        "existing policy extra signed fields for policy %d: %w",
×
4288
                        chanPolicyID, err)
×
4289
        }
×
4290

4291
        // Insert all new extra signed fields for the channel policy.
4292
        for tlvType, value := range extraFields {
×
4293
                err = db.UpsertChanPolicyExtraType(
×
4294
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4295
                                ChannelPolicyID: chanPolicyID,
×
4296
                                Type:            int64(tlvType),
×
4297
                                Value:           value,
×
4298
                        },
×
4299
                )
×
4300
                if err != nil {
×
4301
                        return fmt.Errorf("unable to insert "+
×
4302
                                "channel_policy(%d) extra signed field(%v): %w",
×
4303
                                chanPolicyID, tlvType, err)
×
4304
                }
×
4305
        }
4306

4307
        return nil
×
4308
}
4309

4310
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4311
// provided dbChan row and also fetches any other required information
4312
// to construct the edge info.
4313
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4314
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4315
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4316

×
4317
        data, err := batchLoadChannelData(
×
4318
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4319
        )
×
4320
        if err != nil {
×
4321
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4322
                        err)
×
4323
        }
×
4324

4325
        return buildEdgeInfoWithBatchData(
×
4326
                cfg.ChainHash, dbChan, node1, node2, data,
×
4327
        )
×
4328
}
4329

4330
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4331
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4332
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4333
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4334

×
4335
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4336
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4337
                        dbChan.Version)
×
4338
        }
×
4339

4340
        // Use pre-loaded features and extra types.
4341
        fv := lnwire.EmptyFeatureVector()
×
4342
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4343
                for _, bit := range features {
×
4344
                        fv.Set(lnwire.FeatureBit(bit))
×
4345
                }
×
4346
        }
4347

4348
        var extras map[uint64][]byte
×
4349
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4350
        if exists {
×
4351
                extras = channelExtras
×
4352
        } else {
×
4353
                extras = make(map[uint64][]byte)
×
4354
        }
×
4355

4356
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4357
        if err != nil {
×
4358
                return nil, err
×
4359
        }
×
4360

4361
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4362
        if err != nil {
×
4363
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4364
                        "fields: %w", err)
×
4365
        }
×
4366
        if recs == nil {
×
4367
                recs = make([]byte, 0)
×
4368
        }
×
4369

4370
        var btcKey1, btcKey2 route.Vertex
×
4371
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4372
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4373

×
4374
        channel := &models.ChannelEdgeInfo{
×
4375
                ChainHash:        chain,
×
4376
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4377
                NodeKey1Bytes:    node1,
×
4378
                NodeKey2Bytes:    node2,
×
4379
                BitcoinKey1Bytes: btcKey1,
×
4380
                BitcoinKey2Bytes: btcKey2,
×
4381
                ChannelPoint:     *op,
×
4382
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4383
                Features:         fv,
×
4384
                ExtraOpaqueData:  recs,
×
4385
        }
×
4386

×
4387
        // We always set all the signatures at the same time, so we can
×
4388
        // safely check if one signature is present to determine if we have the
×
4389
        // rest of the signatures for the auth proof.
×
4390
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4391
                channel.AuthProof = &models.ChannelAuthProof{
×
4392
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4393
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4394
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4395
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4396
                }
×
4397
        }
×
4398

4399
        return channel, nil
×
4400
}
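
// NOTE (illustrative sketch, not part of this store): the ChannelID above is
// recovered from the 8-byte scid column with byteOrder.Uint64. Assuming
// byteOrder is big-endian (the usual choice so that SCIDs sort correctly as
// bytes), the round trip between a uint64 SCID and its stored form is just:
//
//    func scidToBytes(scid uint64) []byte {
//        var b [8]byte
//        binary.BigEndian.PutUint64(b[:], scid)
//        return b[:]
//    }
//
//    func scidFromBytes(b []byte) uint64 {
//        return binary.BigEndian.Uint64(b)
//    }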
4401

4402
// buildNodeVertices is a helper that converts raw node public keys
4403
// into route.Vertex instances.
4404
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4405
        route.Vertex, error) {
×
4406

×
4407
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4408
        if err != nil {
×
4409
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4410
                        "create vertex from node1 pubkey: %w", err)
×
4411
        }
×
4412

4413
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4414
        if err != nil {
×
4415
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4416
                        "create vertex from node2 pubkey: %w", err)
×
4417
        }
×
4418

4419
        return node1Vertex, node2Vertex, nil
×
4420
}
4421

4422
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4423
// retrieves all the extra info required to build the complete
4424
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4425
// the provided sqlc.GraphChannelPolicy records are nil.
4426
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4427
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4428
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4429
        *models.ChannelEdgePolicy, error) {
×
4430

×
4431
        if dbPol1 == nil && dbPol2 == nil {
×
4432
                return nil, nil, nil
×
4433
        }
×
4434

4435
        var policyIDs = make([]int64, 0, 2)
×
4436
        if dbPol1 != nil {
×
4437
                policyIDs = append(policyIDs, dbPol1.ID)
×
4438
        }
×
4439
        if dbPol2 != nil {
×
4440
                policyIDs = append(policyIDs, dbPol2.ID)
×
4441
        }
×
4442

4443
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4444
        if err != nil {
×
4445
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4446
                        "data: %w", err)
×
4447
        }
×
4448

4449
        pol1, err := buildChanPolicyWithBatchData(
×
4450
                dbPol1, channelID, node2, batchData,
×
4451
        )
×
4452
        if err != nil {
×
4453
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4454
        }
×
4455

4456
        pol2, err := buildChanPolicyWithBatchData(
×
4457
                dbPol2, channelID, node1, batchData,
×
4458
        )
×
4459
        if err != nil {
×
4460
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4461
        }
×
4462

4463
        return pol1, pol2, nil
×
4464
}
4465

4466
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4467
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4468
// then nil is returned for it.
4469
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4470
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4471
        *models.CachedEdgePolicy, error) {
×
4472

×
4473
        var p1, p2 *models.CachedEdgePolicy
×
4474
        if dbPol1 != nil {
×
4475
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4476
                if err != nil {
×
4477
                        return nil, nil, err
×
4478
                }
×
4479

4480
                p1 = models.NewCachedPolicy(policy1)
×
4481
        }
4482
        if dbPol2 != nil {
×
4483
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4484
                if err != nil {
×
4485
                        return nil, nil, err
×
4486
                }
×
4487

4488
                p2 = models.NewCachedPolicy(policy2)
×
4489
        }
4490

4491
        return p1, p2, nil
×
4492
}
4493

4494
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4495
// provided sqlc.GraphChannelPolicy and other required information.
4496
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4497
        extras map[uint64][]byte,
4498
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4499

×
4500
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4501
        if err != nil {
×
4502
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4503
                        "fields: %w", err)
×
4504
        }
×
4505

4506
        var inboundFee fn.Option[lnwire.Fee]
×
4507
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4508
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4509

×
4510
                inboundFee = fn.Some(lnwire.Fee{
×
4511
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4512
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4513
                })
×
4514
        }
×
4515

4516
        return &models.ChannelEdgePolicy{
×
4517
                SigBytes:  dbPolicy.Signature,
×
4518
                ChannelID: channelID,
×
4519
                LastUpdate: time.Unix(
×
4520
                        dbPolicy.LastUpdate.Int64, 0,
×
4521
                ),
×
4522
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4523
                        dbPolicy.MessageFlags,
×
4524
                ),
×
4525
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4526
                        dbPolicy.ChannelFlags,
×
4527
                ),
×
4528
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4529
                MinHTLC: lnwire.MilliSatoshi(
×
4530
                        dbPolicy.MinHtlcMsat,
×
4531
                ),
×
4532
                MaxHTLC: lnwire.MilliSatoshi(
×
4533
                        dbPolicy.MaxHtlcMsat.Int64,
×
4534
                ),
×
4535
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4536
                        dbPolicy.BaseFeeMsat,
×
4537
                ),
×
4538
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4539
                ToNode:                    toNode,
×
4540
                InboundFee:                inboundFee,
×
4541
                ExtraOpaqueData:           recs,
×
4542
        }, nil
×
4543
}
4544

4545
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4546
// row which is expected to be a sqlc type that contains channel policy
4547
// information. It returns two policies, which may be nil if the policy
4548
// information is not present in the row.
4549
//
4550
//nolint:ll,dupl,funlen
4551
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4552
        *sqlc.GraphChannelPolicy, error) {
×
4553

×
4554
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4555
        switch r := row.(type) {
×
4556
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4557
                if r.Policy1Timelock.Valid {
×
4558
                        policy1 = &sqlc.GraphChannelPolicy{
×
4559
                                Timelock:                r.Policy1Timelock.Int32,
×
4560
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4561
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4562
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4563
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4564
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4565
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4566
                                Disabled:                r.Policy1Disabled,
×
4567
                                MessageFlags:            r.Policy1MessageFlags,
×
4568
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4569
                        }
×
4570
                }
×
4571
                if r.Policy2Timelock.Valid {
×
4572
                        policy2 = &sqlc.GraphChannelPolicy{
×
4573
                                Timelock:                r.Policy2Timelock.Int32,
×
4574
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4575
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4576
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4577
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4578
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4579
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4580
                                Disabled:                r.Policy2Disabled,
×
4581
                                MessageFlags:            r.Policy2MessageFlags,
×
4582
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4583
                        }
×
4584
                }
×
4585

4586
                return policy1, policy2, nil
×
4587

4588
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4589
                if r.Policy1ID.Valid {
×
4590
                        policy1 = &sqlc.GraphChannelPolicy{
×
4591
                                ID:                      r.Policy1ID.Int64,
×
4592
                                Version:                 r.Policy1Version.Int16,
×
4593
                                ChannelID:               r.GraphChannel.ID,
×
4594
                                NodeID:                  r.Policy1NodeID.Int64,
×
4595
                                Timelock:                r.Policy1Timelock.Int32,
×
4596
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4597
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4598
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4599
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4600
                                LastUpdate:              r.Policy1LastUpdate,
×
4601
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4602
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4603
                                Disabled:                r.Policy1Disabled,
×
4604
                                MessageFlags:            r.Policy1MessageFlags,
×
4605
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4606
                                Signature:               r.Policy1Signature,
×
4607
                        }
×
4608
                }
×
4609
                if r.Policy2ID.Valid {
×
4610
                        policy2 = &sqlc.GraphChannelPolicy{
×
4611
                                ID:                      r.Policy2ID.Int64,
×
4612
                                Version:                 r.Policy2Version.Int16,
×
4613
                                ChannelID:               r.GraphChannel.ID,
×
4614
                                NodeID:                  r.Policy2NodeID.Int64,
×
4615
                                Timelock:                r.Policy2Timelock.Int32,
×
4616
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4617
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4618
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4619
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4620
                                LastUpdate:              r.Policy2LastUpdate,
×
4621
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4622
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4623
                                Disabled:                r.Policy2Disabled,
×
4624
                                MessageFlags:            r.Policy2MessageFlags,
×
4625
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4626
                                Signature:               r.Policy2Signature,
×
4627
                        }
×
4628
                }
×
4629

4630
                return policy1, policy2, nil
×
4631

4632
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4633
                if r.Policy1ID.Valid {
×
4634
                        policy1 = &sqlc.GraphChannelPolicy{
×
4635
                                ID:                      r.Policy1ID.Int64,
×
4636
                                Version:                 r.Policy1Version.Int16,
×
4637
                                ChannelID:               r.GraphChannel.ID,
×
4638
                                NodeID:                  r.Policy1NodeID.Int64,
×
4639
                                Timelock:                r.Policy1Timelock.Int32,
×
4640
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4641
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4642
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4643
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4644
                                LastUpdate:              r.Policy1LastUpdate,
×
4645
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4646
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4647
                                Disabled:                r.Policy1Disabled,
×
4648
                                MessageFlags:            r.Policy1MessageFlags,
×
4649
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4650
                                Signature:               r.Policy1Signature,
×
4651
                        }
×
4652
                }
×
4653
                if r.Policy2ID.Valid {
×
4654
                        policy2 = &sqlc.GraphChannelPolicy{
×
4655
                                ID:                      r.Policy2ID.Int64,
×
4656
                                Version:                 r.Policy2Version.Int16,
×
4657
                                ChannelID:               r.GraphChannel.ID,
×
4658
                                NodeID:                  r.Policy2NodeID.Int64,
×
4659
                                Timelock:                r.Policy2Timelock.Int32,
×
4660
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4661
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4662
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4663
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4664
                                LastUpdate:              r.Policy2LastUpdate,
×
4665
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4666
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4667
                                Disabled:                r.Policy2Disabled,
×
4668
                                MessageFlags:            r.Policy2MessageFlags,
×
4669
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4670
                                Signature:               r.Policy2Signature,
×
4671
                        }
×
4672
                }
×
4673

4674
                return policy1, policy2, nil
×
4675

4676
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4677
                if r.Policy1ID.Valid {
×
4678
                        policy1 = &sqlc.GraphChannelPolicy{
×
4679
                                ID:                      r.Policy1ID.Int64,
×
4680
                                Version:                 r.Policy1Version.Int16,
×
4681
                                ChannelID:               r.GraphChannel.ID,
×
4682
                                NodeID:                  r.Policy1NodeID.Int64,
×
4683
                                Timelock:                r.Policy1Timelock.Int32,
×
4684
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4685
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4686
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4687
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4688
                                LastUpdate:              r.Policy1LastUpdate,
×
4689
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4690
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4691
                                Disabled:                r.Policy1Disabled,
×
4692
                                MessageFlags:            r.Policy1MessageFlags,
×
4693
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4694
                                Signature:               r.Policy1Signature,
×
4695
                        }
×
4696
                }
×
4697
                if r.Policy2ID.Valid {
×
4698
                        policy2 = &sqlc.GraphChannelPolicy{
×
4699
                                ID:                      r.Policy2ID.Int64,
×
4700
                                Version:                 r.Policy2Version.Int16,
×
4701
                                ChannelID:               r.GraphChannel.ID,
×
4702
                                NodeID:                  r.Policy2NodeID.Int64,
×
4703
                                Timelock:                r.Policy2Timelock.Int32,
×
4704
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4705
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4706
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4707
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4708
                                LastUpdate:              r.Policy2LastUpdate,
×
4709
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4710
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4711
                                Disabled:                r.Policy2Disabled,
×
4712
                                MessageFlags:            r.Policy2MessageFlags,
×
4713
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4714
                                Signature:               r.Policy2Signature,
×
4715
                        }
×
4716
                }
×
4717

4718
                return policy1, policy2, nil
×
4719

4720
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4721
                if r.Policy1ID.Valid {
×
4722
                        policy1 = &sqlc.GraphChannelPolicy{
×
4723
                                ID:                      r.Policy1ID.Int64,
×
4724
                                Version:                 r.Policy1Version.Int16,
×
4725
                                ChannelID:               r.GraphChannel.ID,
×
4726
                                NodeID:                  r.Policy1NodeID.Int64,
×
4727
                                Timelock:                r.Policy1Timelock.Int32,
×
4728
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4729
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4730
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4731
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4732
                                LastUpdate:              r.Policy1LastUpdate,
×
4733
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4734
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4735
                                Disabled:                r.Policy1Disabled,
×
4736
                                MessageFlags:            r.Policy1MessageFlags,
×
4737
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4738
                                Signature:               r.Policy1Signature,
×
4739
                        }
×
4740
                }
×
4741
                if r.Policy2ID.Valid {
×
4742
                        policy2 = &sqlc.GraphChannelPolicy{
×
4743
                                ID:                      r.Policy2ID.Int64,
×
4744
                                Version:                 r.Policy2Version.Int16,
×
4745
                                ChannelID:               r.GraphChannel.ID,
×
4746
                                NodeID:                  r.Policy2NodeID.Int64,
×
4747
                                Timelock:                r.Policy2Timelock.Int32,
×
4748
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4749
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4750
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4751
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4752
                                LastUpdate:              r.Policy2LastUpdate,
×
4753
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4754
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4755
                                Disabled:                r.Policy2Disabled,
×
4756
                                MessageFlags:            r.Policy2MessageFlags,
×
4757
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4758
                                Signature:               r.Policy2Signature,
×
4759
                        }
×
4760
                }
×
4761

4762
                return policy1, policy2, nil
×
4763

4764
        case sqlc.ListChannelsForNodeIDsRow:
×
4765
                if r.Policy1ID.Valid {
×
4766
                        policy1 = &sqlc.GraphChannelPolicy{
×
4767
                                ID:                      r.Policy1ID.Int64,
×
4768
                                Version:                 r.Policy1Version.Int16,
×
4769
                                ChannelID:               r.GraphChannel.ID,
×
4770
                                NodeID:                  r.Policy1NodeID.Int64,
×
4771
                                Timelock:                r.Policy1Timelock.Int32,
×
4772
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4773
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4774
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4775
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4776
                                LastUpdate:              r.Policy1LastUpdate,
×
4777
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4778
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4779
                                Disabled:                r.Policy1Disabled,
×
4780
                                MessageFlags:            r.Policy1MessageFlags,
×
4781
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4782
                                Signature:               r.Policy1Signature,
×
4783
                        }
×
4784
                }
×
4785
                if r.Policy2ID.Valid {
×
4786
                        policy2 = &sqlc.GraphChannelPolicy{
×
4787
                                ID:                      r.Policy2ID.Int64,
×
4788
                                Version:                 r.Policy2Version.Int16,
×
4789
                                ChannelID:               r.GraphChannel.ID,
×
4790
                                NodeID:                  r.Policy2NodeID.Int64,
×
4791
                                Timelock:                r.Policy2Timelock.Int32,
×
4792
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4793
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4794
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4795
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4796
                                LastUpdate:              r.Policy2LastUpdate,
×
4797
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4798
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4799
                                Disabled:                r.Policy2Disabled,
×
4800
                                MessageFlags:            r.Policy2MessageFlags,
×
4801
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4802
                                Signature:               r.Policy2Signature,
×
4803
                        }
×
4804
                }
×
4805

4806
                return policy1, policy2, nil
×
4807

4808
        case sqlc.ListChannelsByNodeIDRow:
×
4809
                if r.Policy1ID.Valid {
×
4810
                        policy1 = &sqlc.GraphChannelPolicy{
×
4811
                                ID:                      r.Policy1ID.Int64,
×
4812
                                Version:                 r.Policy1Version.Int16,
×
4813
                                ChannelID:               r.GraphChannel.ID,
×
4814
                                NodeID:                  r.Policy1NodeID.Int64,
×
4815
                                Timelock:                r.Policy1Timelock.Int32,
×
4816
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4817
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4818
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4819
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4820
                                LastUpdate:              r.Policy1LastUpdate,
×
4821
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4822
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4823
                                Disabled:                r.Policy1Disabled,
×
4824
                                MessageFlags:            r.Policy1MessageFlags,
×
4825
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4826
                                Signature:               r.Policy1Signature,
×
4827
                        }
×
4828
                }
×
4829
                if r.Policy2ID.Valid {
×
4830
                        policy2 = &sqlc.GraphChannelPolicy{
×
4831
                                ID:                      r.Policy2ID.Int64,
×
4832
                                Version:                 r.Policy2Version.Int16,
×
4833
                                ChannelID:               r.GraphChannel.ID,
×
4834
                                NodeID:                  r.Policy2NodeID.Int64,
×
4835
                                Timelock:                r.Policy2Timelock.Int32,
×
4836
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4837
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4838
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4839
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4840
                                LastUpdate:              r.Policy2LastUpdate,
×
4841
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4842
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4843
                                Disabled:                r.Policy2Disabled,
×
4844
                                MessageFlags:            r.Policy2MessageFlags,
×
4845
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4846
                                Signature:               r.Policy2Signature,
×
4847
                        }
×
4848
                }
×
4849

4850
                return policy1, policy2, nil
×
4851

4852
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4853
                if r.Policy1ID.Valid {
×
4854
                        policy1 = &sqlc.GraphChannelPolicy{
×
4855
                                ID:                      r.Policy1ID.Int64,
×
4856
                                Version:                 r.Policy1Version.Int16,
×
4857
                                ChannelID:               r.GraphChannel.ID,
×
4858
                                NodeID:                  r.Policy1NodeID.Int64,
×
4859
                                Timelock:                r.Policy1Timelock.Int32,
×
4860
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4861
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4862
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4863
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4864
                                LastUpdate:              r.Policy1LastUpdate,
×
4865
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4866
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4867
                                Disabled:                r.Policy1Disabled,
×
4868
                                MessageFlags:            r.Policy1MessageFlags,
×
4869
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4870
                                Signature:               r.Policy1Signature,
×
4871
                        }
×
4872
                }
×
4873
                if r.Policy2ID.Valid {
×
4874
                        policy2 = &sqlc.GraphChannelPolicy{
×
4875
                                ID:                      r.Policy2ID.Int64,
×
4876
                                Version:                 r.Policy2Version.Int16,
×
4877
                                ChannelID:               r.GraphChannel.ID,
×
4878
                                NodeID:                  r.Policy2NodeID.Int64,
×
4879
                                Timelock:                r.Policy2Timelock.Int32,
×
4880
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4881
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4882
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4883
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4884
                                LastUpdate:              r.Policy2LastUpdate,
×
4885
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4886
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4887
                                Disabled:                r.Policy2Disabled,
×
4888
                                MessageFlags:            r.Policy2MessageFlags,
×
4889
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4890
                                Signature:               r.Policy2Signature,
×
4891
                        }
×
4892
                }
×
4893

4894
                return policy1, policy2, nil
×
4895

4896
        case sqlc.GetChannelsByIDsRow:
×
4897
                if r.Policy1ID.Valid {
×
4898
                        policy1 = &sqlc.GraphChannelPolicy{
×
4899
                                ID:                      r.Policy1ID.Int64,
×
4900
                                Version:                 r.Policy1Version.Int16,
×
4901
                                ChannelID:               r.GraphChannel.ID,
×
4902
                                NodeID:                  r.Policy1NodeID.Int64,
×
4903
                                Timelock:                r.Policy1Timelock.Int32,
×
4904
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4905
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4906
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4907
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4908
                                LastUpdate:              r.Policy1LastUpdate,
×
4909
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4910
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4911
                                Disabled:                r.Policy1Disabled,
×
4912
                                MessageFlags:            r.Policy1MessageFlags,
×
4913
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4914
                                Signature:               r.Policy1Signature,
×
4915
                        }
×
4916
                }
×
4917
                if r.Policy2ID.Valid {
×
4918
                        policy2 = &sqlc.GraphChannelPolicy{
×
4919
                                ID:                      r.Policy2ID.Int64,
×
4920
                                Version:                 r.Policy2Version.Int16,
×
4921
                                ChannelID:               r.GraphChannel.ID,
×
4922
                                NodeID:                  r.Policy2NodeID.Int64,
×
4923
                                Timelock:                r.Policy2Timelock.Int32,
×
4924
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4925
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4926
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4927
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4928
                                LastUpdate:              r.Policy2LastUpdate,
×
4929
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4930
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4931
                                Disabled:                r.Policy2Disabled,
×
4932
                                MessageFlags:            r.Policy2MessageFlags,
×
4933
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4934
                                Signature:               r.Policy2Signature,
×
4935
                        }
×
4936
                }
×
4937

4938
                return policy1, policy2, nil
×
4939

4940
        default:
×
4941
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4942
                        "extractChannelPolicies: %T", r)
×
4943
        }
4944
}
4945

4946
// channelIDToBytes converts a channel ID (SCID) to a byte array
4947
// representation.
4948
func channelIDToBytes(channelID uint64) []byte {
×
4949
        var chanIDB [8]byte
×
4950
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4951

×
4952
        return chanIDB[:]
×
4953
}
×
4954
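
// A minimal sketch of the inverse conversion, assuming the same package-level
// byteOrder used above; the helper name bytesToChannelID is illustrative and
// not part of the upstream file.
func bytesToChannelID(b []byte) (uint64, error) {
        if len(b) != 8 {
                return 0, fmt.Errorf("expected 8 bytes for a channel ID, "+
                        "got %d", len(b))
        }

        return byteOrder.Uint64(b), nil
}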

4955
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4956
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4957
        if len(addresses) == 0 {
×
4958
                return nil, nil
×
4959
        }
×
4960

4961
        result := make([]net.Addr, 0, len(addresses))
×
4962
        for _, addr := range addresses {
×
4963
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4964
                if err != nil {
×
4965
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4966
                                "of type %d: %w", addr.address, addr.addrType,
×
4967
                                err)
×
4968
                }
×
4969
                if netAddr != nil {
×
4970
                        result = append(result, netAddr)
×
4971
                }
×
4972
        }
4973

4974
        // If we have no valid addresses, return nil instead of an empty slice.
4975
        if len(result) == 0 {
×
4976
                return nil, nil
×
4977
        }
×
4978

4979
        return result, nil
×
4980
}
4981

4982
// parseAddress parses the given address string based on the address type
4983
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4984
// and opaque addresses.
4985
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4986
        switch addrType {
×
4987
        case addressTypeIPv4:
×
4988
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4989
                if err != nil {
×
4990
                        return nil, err
×
4991
                }
×
4992

4993
                tcp.IP = tcp.IP.To4()
×
4994

×
4995
                return tcp, nil
×
4996

4997
        case addressTypeIPv6:
×
4998
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4999
                if err != nil {
×
5000
                        return nil, err
×
5001
                }
×
5002

5003
                return tcp, nil
×
5004

5005
        case addressTypeTorV3, addressTypeTorV2:
×
5006
                service, portStr, err := net.SplitHostPort(address)
×
5007
                if err != nil {
×
5008
                        return nil, fmt.Errorf("unable to split tor "+
×
5009
                                "address: %v", address)
×
5010
                }
×
5011

5012
                port, err := strconv.Atoi(portStr)
×
5013
                if err != nil {
×
5014
                        return nil, err
×
5015
                }
×
5016

5017
                return &tor.OnionAddr{
×
5018
                        OnionService: service,
×
5019
                        Port:         port,
×
5020
                }, nil
×
5021

5022
        case addressTypeDNS:
×
5023
                hostname, portStr, err := net.SplitHostPort(address)
×
5024
                if err != nil {
×
5025
                        return nil, fmt.Errorf("unable to split DNS "+
×
5026
                                "address: %v", address)
×
5027
                }
×
5028

5029
                port, err := strconv.Atoi(portStr)
×
5030
                if err != nil {
×
5031
                        return nil, err
×
5032
                }
×
5033

5034
                return &lnwire.DNSAddress{
×
5035
                        Hostname: hostname,
×
5036
                        Port:     uint16(port),
×
5037
                }, nil
×
5038

5039
        case addressTypeOpaque:
×
5040
                opaque, err := hex.DecodeString(address)
×
5041
                if err != nil {
×
5042
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5043
                                "address: %v", address)
×
5044
                }
×
5045

5046
                return &lnwire.OpaqueAddrs{
×
5047
                        Payload: opaque,
×
5048
                }, nil
×
5049

5050
        default:
×
5051
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5052
        }
5053
}
5054
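
// A small usage sketch for the parser above: turning a stored IPv4 host:port
// string back into a net.Addr. The literal address and the helper name are
// illustrative only.
func exampleParseIPv4Address() (net.Addr, error) {
        return parseAddress(addressTypeIPv4, "203.0.113.7:9735")
}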

5055
// batchNodeData holds all the related data for a batch of nodes.
5056
type batchNodeData struct {
5057
        // features is a map from a DB node ID to the feature bits for that
5058
        // node.
5059
        features map[int64][]int
5060

5061
        // addresses is a map from a DB node ID to the node's addresses.
5062
        addresses map[int64][]nodeAddress
5063

5064
        // extraFields is a map from a DB node ID to the extra signed fields
5065
        // for that node.
5066
        extraFields map[int64]map[uint64][]byte
5067
}
5068

5069
// nodeAddress holds the address type, position and address string for a
5070
// node. This is used to batch the fetching of node addresses.
5071
type nodeAddress struct {
5072
        addrType dbAddressType
5073
        position int32
5074
        address  string
5075
}
5076

5077
// batchLoadNodeData loads all related data for a batch of node IDs using the
5078
// provided SQLQueries interface. It returns a batchNodeData instance containing
5079
// the node features, addresses and extra signed fields.
5080
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5081
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5082

×
5083
        // Batch load the node features.
×
5084
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5085
        if err != nil {
×
5086
                return nil, fmt.Errorf("unable to batch load node "+
×
5087
                        "features: %w", err)
×
5088
        }
×
5089

5090
        // Batch load the node addresses.
5091
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5092
        if err != nil {
×
5093
                return nil, fmt.Errorf("unable to batch load node "+
×
5094
                        "addresses: %w", err)
×
5095
        }
×
5096

5097
        // Batch load the node extra signed fields.
5098
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5099
        if err != nil {
×
5100
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5101
                        "signed fields: %w", err)
×
5102
        }
×
5103

5104
        return &batchNodeData{
×
5105
                features:    features,
×
5106
                addresses:   addrs,
×
5107
                extraFields: extraTypes,
×
5108
        }, nil
×
5109
}
5110
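
// A minimal sketch of consuming the batched node data returned above: every
// map is keyed by the DB node ID, so per-node lookups are simple map reads.
// The helper name and the printed summary are illustrative only.
func exampleSummarizeNodeData(data *batchNodeData, nodeIDs []int64) {
        for _, id := range nodeIDs {
                fmt.Printf("node %d: %d feature bits, %d addresses, %d extra "+
                        "TLV records\n", id, len(data.features[id]),
                        len(data.addresses[id]), len(data.extraFields[id]))
        }
}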

5111
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5112
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5113
func batchLoadNodeFeaturesHelper(ctx context.Context,
5114
        cfg *sqldb.QueryConfig, db SQLQueries,
5115
        nodeIDs []int64) (map[int64][]int, error) {
×
5116

×
5117
        features := make(map[int64][]int)
×
5118

×
5119
        return features, sqldb.ExecuteBatchQuery(
×
5120
                ctx, cfg, nodeIDs,
×
5121
                func(id int64) int64 {
×
5122
                        return id
×
5123
                },
×
5124
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5125
                        error) {
×
5126

×
5127
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5128
                },
×
5129
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5130
                        features[feature.NodeID] = append(
×
5131
                                features[feature.NodeID],
×
5132
                                int(feature.FeatureBit),
×
5133
                        )
×
5134

×
5135
                        return nil
×
5136
                },
×
5137
        )
5138
}
5139

5140
// batchLoadNodeAddressesHelper loads node addresses using the
5141
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
5142
// returns a map from node ID to a slice of nodeAddress structs.
5143
func batchLoadNodeAddressesHelper(ctx context.Context,
5144
        cfg *sqldb.QueryConfig, db SQLQueries,
5145
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5146

×
5147
        addrs := make(map[int64][]nodeAddress)
×
5148

×
5149
        return addrs, sqldb.ExecuteBatchQuery(
×
5150
                ctx, cfg, nodeIDs,
×
5151
                func(id int64) int64 {
×
5152
                        return id
×
5153
                },
×
5154
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5155
                        error) {
×
5156

×
5157
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5158
                },
×
5159
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5160
                        addrs[addr.NodeID] = append(
×
5161
                                addrs[addr.NodeID], nodeAddress{
×
5162
                                        addrType: dbAddressType(addr.Type),
×
5163
                                        position: addr.Position,
×
5164
                                        address:  addr.Address,
×
5165
                                },
×
5166
                        )
×
5167

×
5168
                        return nil
×
5169
                },
×
5170
        )
5171
}
5172

5173
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5174
// node IDs using the ExecuteBatchQuery wrapper around the
5175
// GetNodeExtraTypesBatch query.
5176
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5177
        cfg *sqldb.QueryConfig, db SQLQueries,
5178
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5179

×
5180
        extraFields := make(map[int64]map[uint64][]byte)
×
5181

×
5182
        callback := func(ctx context.Context,
×
5183
                field sqlc.GraphNodeExtraType) error {
×
5184

×
5185
                if extraFields[field.NodeID] == nil {
×
5186
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5187
                }
×
5188
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5189

×
5190
                return nil
×
5191
        }
5192

5193
        return extraFields, sqldb.ExecuteBatchQuery(
×
5194
                ctx, cfg, nodeIDs,
×
5195
                func(id int64) int64 {
×
5196
                        return id
×
5197
                },
×
5198
                func(ctx context.Context, ids []int64) (
5199
                        []sqlc.GraphNodeExtraType, error) {
×
5200

×
5201
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5202
                },
×
5203
                callback,
5204
        )
5205
}
5206

5207
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5208
// from the provided sqlc.GraphChannelPolicy records and the
5209
// provided batchChannelData.
5210
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5211
        channelID uint64, node1, node2 route.Vertex,
5212
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5213
        *models.ChannelEdgePolicy, error) {
×
5214

×
5215
        pol1, err := buildChanPolicyWithBatchData(
×
5216
                dbPol1, channelID, node2, batchData,
×
5217
        )
×
5218
        if err != nil {
×
5219
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5220
        }
×
5221

5222
        pol2, err := buildChanPolicyWithBatchData(
×
5223
                dbPol2, channelID, node1, batchData,
×
5224
        )
×
5225
        if err != nil {
×
5226
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5227
        }
×
5228

5229
        return pol1, pol2, nil
×
5230
}
5231

5232
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5233
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5234
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5235
        channelID uint64, toNode route.Vertex,
5236
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5237

×
5238
        if dbPol == nil {
×
5239
                return nil, nil
×
5240
        }
×
5241

5242
        var dbPol1Extras map[uint64][]byte
×
5243
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5244
                dbPol1Extras = extras
×
5245
        } else {
×
5246
                dbPol1Extras = make(map[uint64][]byte)
×
5247
        }
×
5248

5249
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5250
}
5251

5252
// batchChannelData holds all the related data for a batch of channels.
5253
type batchChannelData struct {
5254
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
5255
        chanfeatures map[int64][]int
5256

5257
        // chanExtras is a map from DB channel ID to a map of TLV type to
5258
        // extra signed field bytes.
5259
        chanExtraTypes map[int64]map[uint64][]byte
5260

5261
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5262
        // to extra signed field bytes.
5263
        policyExtras map[int64]map[uint64][]byte
5264
}
5265

5266
// batchLoadChannelData loads all related data for batches of channels and
5267
// policies.
5268
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5269
        db SQLQueries, channelIDs []int64,
5270
        policyIDs []int64) (*batchChannelData, error) {
×
5271

×
5272
        batchData := &batchChannelData{
×
5273
                chanfeatures:   make(map[int64][]int),
×
5274
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5275
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5276
        }
×
5277

×
5278
        // Batch load channel features and extras
×
5279
        var err error
×
5280
        if len(channelIDs) > 0 {
×
5281
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5282
                        ctx, cfg, db, channelIDs,
×
5283
                )
×
5284
                if err != nil {
×
5285
                        return nil, fmt.Errorf("unable to batch load "+
×
5286
                                "channel features: %w", err)
×
5287
                }
×
5288

5289
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5290
                        ctx, cfg, db, channelIDs,
×
5291
                )
×
5292
                if err != nil {
×
5293
                        return nil, fmt.Errorf("unable to batch load "+
×
5294
                                "channel extras: %w", err)
×
5295
                }
×
5296
        }
5297

5298
        if len(policyIDs) > 0 {
×
5299
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5300
                        ctx, cfg, db, policyIDs,
×
5301
                )
×
5302
                if err != nil {
×
5303
                        return nil, fmt.Errorf("unable to batch load "+
×
5304
                                "policy extras: %w", err)
×
5305
                }
×
5306
                batchData.policyExtras = policyExtras
×
5307
        }
5308

5309
        return batchData, nil
×
5310
}
5311
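
// A minimal sketch of reading the per-policy TLV fields loaded above. Callers
// in this file pass the map into buildChanPolicyWithBatchData instead; the
// helper name and the printing are illustrative only.
func exampleDumpPolicyExtras(data *batchChannelData, policyID int64) {
        extras, ok := data.policyExtras[policyID]
        if !ok {
                fmt.Printf("policy %d has no extra TLV fields\n", policyID)
                return
        }

        for tlvType, value := range extras {
                fmt.Printf("policy %d: TLV type %d carries %d bytes\n",
                        policyID, tlvType, len(value))
        }
}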

5312
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5313
// channel IDs using the ExecuteBatchQuery wrapper around the
5314
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5315
// slice of feature bits.
5316
func batchLoadChannelFeaturesHelper(ctx context.Context,
5317
        cfg *sqldb.QueryConfig, db SQLQueries,
5318
        channelIDs []int64) (map[int64][]int, error) {
×
5319

×
5320
        features := make(map[int64][]int)
×
5321

×
5322
        return features, sqldb.ExecuteBatchQuery(
×
5323
                ctx, cfg, channelIDs,
×
5324
                func(id int64) int64 {
×
5325
                        return id
×
5326
                },
×
5327
                func(ctx context.Context,
5328
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5329

×
5330
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5331
                },
×
5332
                func(ctx context.Context,
5333
                        feature sqlc.GraphChannelFeature) error {
×
5334

×
5335
                        features[feature.ChannelID] = append(
×
5336
                                features[feature.ChannelID],
×
5337
                                int(feature.FeatureBit),
×
5338
                        )
×
5339

×
5340
                        return nil
×
5341
                },
×
5342
        )
5343
}
5344

5345
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5346
// channel IDs using the ExecuteBatchQuery wrapper around the
5347
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
5348
// of TLV type to extra signed field bytes.
5349
func batchLoadChannelExtrasHelper(ctx context.Context,
5350
        cfg *sqldb.QueryConfig, db SQLQueries,
5351
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5352

×
5353
        extras := make(map[int64]map[uint64][]byte)
×
5354

×
5355
        cb := func(ctx context.Context,
×
5356
                extra sqlc.GraphChannelExtraType) error {
×
5357

×
5358
                if extras[extra.ChannelID] == nil {
×
5359
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5360
                }
×
5361
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5362

×
5363
                return nil
×
5364
        }
5365

5366
        return extras, sqldb.ExecuteBatchQuery(
×
5367
                ctx, cfg, channelIDs,
×
5368
                func(id int64) int64 {
×
5369
                        return id
×
5370
                },
×
5371
                func(ctx context.Context,
5372
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5373

×
5374
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5375
                }, cb,
×
5376
        )
5377
}
5378

5379
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5380
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
5381
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5382
// a map of TLV type to extra signed field bytes.
5383
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5384
        cfg *sqldb.QueryConfig, db SQLQueries,
5385
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5386

×
5387
        extras := make(map[int64]map[uint64][]byte)
×
5388

×
5389
        return extras, sqldb.ExecuteBatchQuery(
×
5390
                ctx, cfg, policyIDs,
×
5391
                func(id int64) int64 {
×
5392
                        return id
×
5393
                },
×
5394
                func(ctx context.Context, ids []int64) (
5395
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5396

×
5397
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5398
                },
×
5399
                func(ctx context.Context,
5400
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5401

×
5402
                        if extras[row.PolicyID] == nil {
×
5403
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5404
                        }
×
5405
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5406

×
5407
                        return nil
×
5408
                },
5409
        )
5410
}
5411

5412
// forEachNodePaginated executes a paginated query to process each node in the
5413
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5414
// and applies the provided processNode function to each node.
5415
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5416
        db SQLQueries, protocol lnwire.GossipVersion,
5417
        processNode func(context.Context, int64,
5418
                *models.Node) error) error {
×
5419

×
5420
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5421
                limit int32) ([]sqlc.GraphNode, error) {
×
5422

×
5423
                return db.ListNodesPaginated(
×
5424
                        ctx, sqlc.ListNodesPaginatedParams{
×
5425
                                Version: int16(protocol),
×
5426
                                ID:      lastID,
×
5427
                                Limit:   limit,
×
5428
                        },
×
5429
                )
×
5430
        }
×
5431

5432
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5433
                return node.ID
×
5434
        }
×
5435

5436
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5437
                return node.ID, nil
×
5438
        }
×
5439

5440
        batchQueryFunc := func(ctx context.Context,
×
5441
                nodeIDs []int64) (*batchNodeData, error) {
×
5442

×
5443
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5444
        }
×
5445

5446
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5447
                batchData *batchNodeData) error {
×
5448

×
5449
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5450
                if err != nil {
×
5451
                        return fmt.Errorf("unable to build "+
×
5452
                                "node(id=%d): %w", dbNode.ID, err)
×
5453
                }
×
5454

5455
                return processNode(ctx, dbNode.ID, node)
×
5456
        }
5457

5458
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5459
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5460
                collectFunc, batchQueryFunc, processItem,
×
5461
        )
×
5462
}
5463
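
// A small usage sketch for the paginated traversal above: count every node of
// a given gossip version by passing a callback that only bumps a counter. The
// helper name is illustrative; cfg and db are assumed to come from the
// surrounding SQLStore.
func exampleCountNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries) (int, error) {

        var count int
        err := forEachNodePaginated(
                ctx, cfg, db, lnwire.GossipVersion1,
                func(_ context.Context, _ int64, _ *models.Node) error {
                        count++
                        return nil
                },
        )

        return count, err
}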

5464
// forEachChannelWithPolicies executes a paginated query to process each channel
5465
// with policies in the graph.
5466
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5467
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5468
                *models.ChannelEdgePolicy,
5469
                *models.ChannelEdgePolicy) error) error {
×
5470

×
5471
        type channelBatchIDs struct {
×
5472
                channelID int64
×
5473
                policyIDs []int64
×
5474
        }
×
5475

×
5476
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5477
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5478
                error) {
×
5479

×
5480
                return db.ListChannelsWithPoliciesPaginated(
×
5481
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5482
                                Version: int16(lnwire.GossipVersion1),
×
5483
                                ID:      lastID,
×
5484
                                Limit:   limit,
×
5485
                        },
×
5486
                )
×
5487
        }
×
5488

5489
        extractPageCursor := func(
×
5490
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5491

×
5492
                return row.GraphChannel.ID
×
5493
        }
×
5494

5495
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5496
                channelBatchIDs, error) {
×
5497

×
5498
                ids := channelBatchIDs{
×
5499
                        channelID: row.GraphChannel.ID,
×
5500
                }
×
5501

×
5502
                // Extract policy IDs from the row.
×
5503
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5504
                if err != nil {
×
5505
                        return ids, err
×
5506
                }
×
5507

5508
                if dbPol1 != nil {
×
5509
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5510
                }
×
5511
                if dbPol2 != nil {
×
5512
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5513
                }
×
5514

5515
                return ids, nil
×
5516
        }
5517

5518
        batchDataFunc := func(ctx context.Context,
×
5519
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5520

×
5521
                // Separate channel IDs from policy IDs.
×
5522
                var (
×
5523
                        channelIDs = make([]int64, len(allIDs))
×
5524
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5525
                )
×
5526

×
5527
                for i, ids := range allIDs {
×
5528
                        channelIDs[i] = ids.channelID
×
5529
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5530
                }
×
5531

5532
                return batchLoadChannelData(
×
5533
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5534
                )
×
5535
        }
5536

5537
        processItem := func(ctx context.Context,
×
5538
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5539
                batchData *batchChannelData) error {
×
5540

×
5541
                node1, node2, err := buildNodeVertices(
×
5542
                        row.Node1Pubkey, row.Node2Pubkey,
×
5543
                )
×
5544
                if err != nil {
×
5545
                        return err
×
5546
                }
×
5547

5548
                edge, err := buildEdgeInfoWithBatchData(
×
5549
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5550
                        batchData,
×
5551
                )
×
5552
                if err != nil {
×
5553
                        return fmt.Errorf("unable to build channel info: %w",
×
5554
                                err)
×
5555
                }
×
5556

5557
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5558
                if err != nil {
×
5559
                        return err
×
5560
                }
×
5561

5562
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5563
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5564
                )
×
5565
                if err != nil {
×
5566
                        return err
×
5567
                }
×
5568

5569
                return processChannel(edge, p1, p2)
×
5570
        }
5571

5572
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5573
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5574
                collectFunc, batchDataFunc, processItem,
×
5575
        )
×
5576
}
5577

5578
// buildDirectedChannel builds a DirectedChannel instance from the provided
5579
// data.
5580
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5581
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5582
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5583
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5584

×
5585
        node1, node2, err := buildNodeVertices(
×
5586
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5587
        )
×
5588
        if err != nil {
×
5589
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5590
        }
×
5591

5592
        edge, err := buildEdgeInfoWithBatchData(
×
5593
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5594
        )
×
5595
        if err != nil {
×
5596
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5597
        }
×
5598

5599
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5600
        if err != nil {
×
5601
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5602
                        err)
×
5603
        }
×
5604

5605
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5606
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5607
                channelBatchData,
×
5608
        )
×
5609
        if err != nil {
×
5610
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5611
                        err)
×
5612
        }
×
5613

5614
        // Determine outgoing and incoming policy for this specific node.
5615
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5616
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5617
        outPolicy, inPolicy := p1, p2
×
5618
        if (p1 != nil && p1ToNode == nodeID) ||
×
5619
                (p2 != nil && p2ToNode != nodeID) {
×
5620

×
5621
                outPolicy, inPolicy = p2, p1
×
5622
        }
×
5623

5624
        // Build cached policy.
5625
        var cachedInPolicy *models.CachedEdgePolicy
×
5626
        if inPolicy != nil {
×
5627
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5628
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5629
                cachedInPolicy.ToNodeFeatures = features
×
5630
        }
×
5631

5632
        // Extract inbound fee.
5633
        var inboundFee lnwire.Fee
×
5634
        if outPolicy != nil {
×
5635
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5636
                        inboundFee = fee
×
5637
                })
×
5638
        }
5639

5640
        // Build directed channel.
5641
        directedChannel := &DirectedChannel{
×
5642
                ChannelID:    edge.ChannelID,
×
5643
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5644
                OtherNode:    edge.NodeKey2Bytes,
×
5645
                Capacity:     edge.Capacity,
×
5646
                OutPolicySet: outPolicy != nil,
×
5647
                InPolicy:     cachedInPolicy,
×
5648
                InboundFee:   inboundFee,
×
5649
        }
×
5650

×
5651
        if nodePub == edge.NodeKey2Bytes {
×
5652
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5653
        }
×
5654

5655
        return directedChannel, nil
×
5656
}
5657

5658
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5659
// provided rows. It uses batch loading for channels, policies, and nodes.
5660
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5661
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5662

×
5663
        var (
×
5664
                channelIDs = make([]int64, len(rows))
×
5665
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5666
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5667

×
5668
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5669
                nodeIDSet = make(map[int64]bool)
×
5670

×
5671
                // edges will hold the final channel edges built from the rows.
×
5672
                edges = make([]ChannelEdge, 0, len(rows))
×
5673
        )
×
5674

×
5675
        // Collect all IDs needed for batch loading.
×
5676
        for i, row := range rows {
×
5677
                channelIDs[i] = row.Channel().ID
×
5678

×
5679
                // Collect policy IDs
×
5680
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5681
                if err != nil {
×
5682
                        return nil, fmt.Errorf("unable to extract channel "+
×
5683
                                "policies: %w", err)
×
5684
                }
×
5685
                if dbPol1 != nil {
×
5686
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5687
                }
×
5688
                if dbPol2 != nil {
×
5689
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5690
                }
×
5691

5692
                var (
×
5693
                        node1ID = row.Node1().ID
×
5694
                        node2ID = row.Node2().ID
×
5695
                )
×
5696

×
5697
                // Collect unique node IDs.
×
5698
                if !nodeIDSet[node1ID] {
×
5699
                        nodeIDs = append(nodeIDs, node1ID)
×
5700
                        nodeIDSet[node1ID] = true
×
5701
                }
×
5702

5703
                if !nodeIDSet[node2ID] {
×
5704
                        nodeIDs = append(nodeIDs, node2ID)
×
5705
                        nodeIDSet[node2ID] = true
×
5706
                }
×
5707
        }
5708

5709
        // Batch the data for all the channels and policies.
5710
        channelBatchData, err := batchLoadChannelData(
×
5711
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5712
        )
×
5713
        if err != nil {
×
5714
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5715
                        "policy data: %w", err)
×
5716
        }
×
5717

5718
        // Batch the data for all the nodes.
5719
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5720
        if err != nil {
×
5721
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5722
                        err)
×
5723
        }
×
5724

5725
        // Build all channel edges using batch data.
5726
        for _, row := range rows {
×
5727
                // Build nodes using batch data.
×
5728
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5729
                if err != nil {
×
5730
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5731
                }
×
5732

5733
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5734
                if err != nil {
×
5735
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5736
                }
×
5737

5738
                // Build channel info using batch data.
5739
                channel, err := buildEdgeInfoWithBatchData(
×
5740
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5741
                        node2.PubKeyBytes, channelBatchData,
×
5742
                )
×
5743
                if err != nil {
×
5744
                        return nil, fmt.Errorf("unable to build channel "+
×
5745
                                "info: %w", err)
×
5746
                }
×
5747

5748
                // Extract and build policies using batch data.
5749
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5750
                if err != nil {
×
5751
                        return nil, fmt.Errorf("unable to extract channel "+
×
5752
                                "policies: %w", err)
×
5753
                }
×
5754

5755
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5756
                        dbPol1, dbPol2, channel.ChannelID,
×
5757
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5758
                )
×
5759
                if err != nil {
×
5760
                        return nil, fmt.Errorf("unable to build channel "+
×
5761
                                "policies: %w", err)
×
5762
                }
×
5763

5764
                edges = append(edges, ChannelEdge{
×
5765
                        Info:    channel,
×
5766
                        Policy1: p1,
×
5767
                        Policy2: p2,
×
5768
                        Node1:   node1,
×
5769
                        Node2:   node2,
×
5770
                })
×
5771
        }
5772

5773
        return edges, nil
×
5774
}
5775

5776
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5777
// instances from the provided rows using batch loading for channel data.
5778
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5779
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5780
        []*models.ChannelEdgeInfo, []int64, error) {
×
5781

×
5782
        if len(rows) == 0 {
×
5783
                return nil, nil, nil
×
5784
        }
×
5785

5786
        // Collect all the channel IDs needed for batch loading.
5787
        channelIDs := make([]int64, len(rows))
×
5788
        for i, row := range rows {
×
5789
                channelIDs[i] = row.Channel().ID
×
5790
        }
×
5791

5792
        // Batch load the channel data.
5793
        channelBatchData, err := batchLoadChannelData(
×
5794
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5795
        )
×
5796
        if err != nil {
×
5797
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5798
                        "data: %w", err)
×
5799
        }
×
5800

5801
        // Build all channel edges using batch data.
5802
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5803
        for _, row := range rows {
×
5804
                node1, node2, err := buildNodeVertices(
×
5805
                        row.Node1Pub(), row.Node2Pub(),
×
5806
                )
×
5807
                if err != nil {
×
5808
                        return nil, nil, err
×
5809
                }
×
5810

5811
                // Build channel info using batch data
5812
                info, err := buildEdgeInfoWithBatchData(
×
5813
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5814
                        channelBatchData,
×
5815
                )
×
5816
                if err != nil {
×
5817
                        return nil, nil, err
×
5818
                }
×
5819

5820
                edges = append(edges, info)
×
5821
        }
5822

5823
        return edges, channelIDs, nil
×
5824
}
5825

5826
// handleZombieMarking is a helper function that handles the logic of
5827
// marking a channel as a zombie in the database. It takes into account whether
5828
// we are in strict zombie pruning mode, and adjusts the node public keys
5829
// accordingly based on the last update timestamps of the channel policies.
5830
func handleZombieMarking(ctx context.Context, db SQLQueries,
5831
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5832
        strictZombiePruning bool, scid uint64) error {
×
5833

×
5834
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5835

×
5836
        if strictZombiePruning {
×
5837
                var e1UpdateTime, e2UpdateTime *time.Time
×
5838
                if row.Policy1LastUpdate.Valid {
×
5839
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5840
                        e1UpdateTime = &e1Time
×
5841
                }
×
5842
                if row.Policy2LastUpdate.Valid {
×
5843
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5844
                        e2UpdateTime = &e2Time
×
5845
                }
×
5846

5847
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5848
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5849
                        e2UpdateTime,
×
5850
                )
×
5851
        }
5852

5853
        return db.UpsertZombieChannel(
×
5854
                ctx, sqlc.UpsertZombieChannelParams{
×
5855
                        Version:  int16(lnwire.GossipVersion1),
×
5856
                        Scid:     channelIDToBytes(scid),
×
5857
                        NodeKey1: nodeKey1[:],
×
5858
                        NodeKey2: nodeKey2[:],
×
5859
                },
×
5860
        )
×
5861
}