• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 23185093231

17 Mar 2026 08:25AM UTC coverage: 62.303% (-0.03%) from 62.33%
23185093231

push

github

web-flow
Merge pull request #10582 from ellemouton/g175-db-8

[g175] graph/db: add versioned range queries and complete v2 graph query migration

210 of 471 new or added lines in 11 files covered. (44.59%)

87 existing lines in 22 files now uncovered.

140872 of 226108 relevant lines covered (62.3%)

19402.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// Short file-local aliases for the lnwire gossip protocol versions. Several
// methods below switch on these (e.g. DisabledChannelIDs) or convert them to
// int16 for the sqlc query parameters.
const (
	// gossipV1 is an alias for lnwire.GossipVersion1.
	gossipV1 = lnwire.GossipVersion1

	// gossipV2 is an alias for lnwire.GossipVersion2.
	gossipV2 = lnwire.GossipVersion2
)
40

41
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. Methods are grouped below by
// the kind of graph data they operate on.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	IsPublicV2Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error
	NodeExists(ctx context.Context, arg sqlc.NodeExistsParams) (bool, error)

	// Node extra-type records.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node address records.
	UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature records, followed by the per-version disabled
	// channel SCID lookups used by DisabledChannelIDs.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
	GetV2DisabledSCIDs(ctx context.Context) ([][]byte, error)

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	AddV2ChannelProof(ctx context.Context, arg sqlc.AddV2ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	ListChannelsPaginatedV2(ctx context.Context, arg sqlc.ListChannelsPaginatedV2Params) ([]sqlc.ListChannelsPaginatedV2Row, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetPublicV2ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV2ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	// Channel extra-type and feature records.
	UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)

	// Channel policy extra-type records.
	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
	GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

	/*
		Migration specific queries.

		NOTE: these should not be used in code other than migrations.
		Once sqldbv2 is in place, these can be removed from this struct
		as then migrations will have their own dedicated queries
		structs.
	*/
	InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
	InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
	InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}
169

170
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations. It embeds the graph query set together with
// sqldb.BatchedTx[SQLQueries]. The SQLStore methods use the latter through
// ExecTx to run closures over SQLQueries inside a read or write transaction.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
176

177
// SQLStore is an implementation of the Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	// cfg is the store configuration, and db is the batched SQL backend
	// that every query in this file is executed against.
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler and nodeScheduler batch requests submitted to the
	// DB. NewSQLStore gives chanScheduler a pointer to cacheMu, while
	// nodeScheduler is created with no lock. AddNode submits its upserts
	// through nodeScheduler.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// srcNodes holds source node info keyed by gossip version.
	// NOTE(review): presumably guarded by srcNodeMu; confirm against
	// getSourceNode.
	srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}
196

197
// A compile-time assertion to ensure that SQLStore implements the Store
// interface. The typed nil pointer costs nothing at runtime; the build fails
// if any Store method is missing from *SQLStore.
var _ Store = (*SQLStore)(nil)
200

201
// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryCfg holds configuration values for SQL queries. The store's
	// node lookups (e.g. FetchNode and SourceNode) pass it on to
	// getNodeByPubKey.
	QueryCfg *sqldb.QueryConfig
}
210

211
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
212
// storage backend.
213
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
214
        options ...StoreOptionModifier) (*SQLStore, error) {
×
215

×
216
        opts := DefaultOptions()
×
217
        for _, o := range options {
×
218
                o(opts)
×
219
        }
×
220

221
        if opts.NoMigration {
×
222
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
223
                        "supported for SQL stores")
×
224
        }
×
225

226
        s := &SQLStore{
×
227
                cfg:         cfg,
×
228
                db:          db,
×
229
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
230
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
231
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
232
        }
×
233

×
234
        s.chanScheduler = batch.NewTimeScheduler(
×
235
                db, &s.cacheMu, opts.BatchCommitInterval,
×
236
        )
×
237
        s.nodeScheduler = batch.NewTimeScheduler(
×
238
                db, nil, opts.BatchCommitInterval,
×
239
        )
×
240

×
241
        return s, nil
×
242
}
243

244
// AddNode adds a vertex/node to the graph database. If the node is not
245
// in the database from before, this will add a new, unconnected one to the
246
// graph. If it is present from before, this will update that node's
247
// information.
248
//
249
// NOTE: part of the Store interface.
250
func (s *SQLStore) AddNode(ctx context.Context,
251
        node *models.Node, opts ...batch.SchedulerOption) error {
×
252

×
253
        r := &batch.Request[SQLQueries]{
×
254
                Opts: batch.NewSchedulerOptions(opts...),
×
255
                Do: func(queries SQLQueries) error {
×
256
                        _, err := upsertNode(ctx, queries, node)
×
257

×
258
                        // It is possible that two of the same node
×
259
                        // announcements are both being processed in the same
×
260
                        // batch. This may case the UpsertNode conflict to
×
261
                        // be hit since we require at the db layer that the
×
262
                        // new last_update is greater than the existing
×
263
                        // last_update. We need to gracefully handle this here.
×
264
                        if errors.Is(err, sql.ErrNoRows) {
×
265
                                return nil
×
266
                        }
×
267

268
                        return err
×
269
                },
270
        }
271

272
        return s.nodeScheduler.Execute(ctx, r)
×
273
}
274

275
// FetchNode attempts to look up a target node by its identity public
276
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
277
// returned.
278
//
279
// NOTE: part of the Store interface.
280
func (s *SQLStore) FetchNode(ctx context.Context, v lnwire.GossipVersion,
281
        pubKey route.Vertex) (*models.Node, error) {
×
282

×
283
        var node *models.Node
×
284
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
285
                var err error
×
286
                _, node, err = getNodeByPubKey(
×
287
                        ctx, s.cfg.QueryCfg, db, v, pubKey,
×
288
                )
×
289

×
290
                return err
×
291
        }, sqldb.NoOpReset)
×
292
        if err != nil {
×
293
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
294
        }
×
295

296
        return node, nil
×
297
}
298

299
// HasV1Node determines if the graph has a vertex identified by the
300
// target node identity public key. If the node exists in the database, a
301
// timestamp of when the data for the node was lasted updated is returned along
302
// with a true boolean. Otherwise, an empty time.Time is returned with a false
303
// boolean.
304
//
305
// NOTE: part of the Store interface.
306
func (s *SQLStore) HasV1Node(ctx context.Context,
307
        pubKey [33]byte) (time.Time, bool, error) {
×
308

×
309
        var (
×
310
                exists     bool
×
311
                lastUpdate time.Time
×
312
        )
×
313
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
314
                dbNode, err := db.GetNodeByPubKey(
×
315
                        ctx, sqlc.GetNodeByPubKeyParams{
×
316
                                Version: int16(gossipV1),
×
317
                                PubKey:  pubKey[:],
×
318
                        },
×
319
                )
×
320
                if errors.Is(err, sql.ErrNoRows) {
×
321
                        return nil
×
322
                } else if err != nil {
×
323
                        return fmt.Errorf("unable to fetch node: %w", err)
×
324
                }
×
325

326
                exists = true
×
327

×
328
                if dbNode.LastUpdate.Valid {
×
329
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return time.Time{}, false,
×
336
                        fmt.Errorf("unable to fetch node: %w", err)
×
337
        }
×
338

339
        return lastUpdate, exists, nil
×
340
}
341

342
// HasNode determines if the graph has a vertex identified by the
343
// target node identity public key.
344
//
345
// NOTE: part of the Store interface.
346
func (s *SQLStore) HasNode(ctx context.Context, v lnwire.GossipVersion,
347
        pubKey [33]byte) (bool, error) {
×
348

×
349
        var exists bool
×
350
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
351
                var err error
×
352
                exists, err = db.NodeExists(ctx, sqlc.NodeExistsParams{
×
353
                        Version: int16(v),
×
354
                        PubKey:  pubKey[:],
×
355
                })
×
356

×
357
                return err
×
358
        }, sqldb.NoOpReset)
×
359
        if err != nil {
×
360
                return false, fmt.Errorf("unable to check if node (%x) "+
×
361
                        "exists: %w", pubKey, err)
×
362
        }
×
363

364
        return exists, nil
×
365
}
366

367
// AddrsForNode returns all known addresses for the target node public key
368
// that the graph DB is aware of. The returned boolean indicates if the
369
// given node is unknown to the graph DB or not.
370
//
371
// NOTE: part of the Store interface.
372
func (s *SQLStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion,
373
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
374

×
375
        var (
×
376
                addresses []net.Addr
×
377
                known     bool
×
378
        )
×
379
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
380
                // First, check if the node exists and get its DB ID if it
×
381
                // does.
×
382
                dbID, err := db.GetNodeIDByPubKey(
×
383
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
384
                                Version: int16(v),
×
385
                                PubKey:  nodePub.SerializeCompressed(),
×
386
                        },
×
387
                )
×
388
                if errors.Is(err, sql.ErrNoRows) {
×
389
                        return nil
×
390
                }
×
391

392
                known = true
×
393

×
394
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
395
                if err != nil {
×
396
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
397
                                err)
×
398
                }
×
399

400
                return nil
×
401
        }, sqldb.NoOpReset)
402
        if err != nil {
×
403
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
404
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
405
        }
×
406

407
        return known, addresses, nil
×
408
}
409

410
// DeleteNode starts a new database transaction to remove a vertex/node
411
// from the database according to the node's public key.
412
//
413
// NOTE: part of the Store interface.
414
func (s *SQLStore) DeleteNode(ctx context.Context, v lnwire.GossipVersion,
415
        pubKey route.Vertex) error {
×
416

×
417
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
418
                res, err := db.DeleteNodeByPubKey(
×
419
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
420
                                Version: int16(v),
×
421
                                PubKey:  pubKey[:],
×
422
                        },
×
423
                )
×
424
                if err != nil {
×
425
                        return err
×
426
                }
×
427

428
                rows, err := res.RowsAffected()
×
429
                if err != nil {
×
430
                        return err
×
431
                }
×
432

433
                if rows == 0 {
×
434
                        return ErrGraphNodeNotFound
×
435
                } else if rows > 1 {
×
436
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
437
                }
×
438

439
                return err
×
440
        }, sqldb.NoOpReset)
441
        if err != nil {
×
442
                return fmt.Errorf("unable to delete node: %w", err)
×
443
        }
×
444

445
        return nil
×
446
}
447

448
// FetchNodeFeatures returns the features of the given node. If no features are
449
// known for the node, an empty feature vector is returned.
450
//
451
// NOTE: this is part of the graphdb.NodeTraverser interface.
452
func (s *SQLStore) FetchNodeFeatures(ctx context.Context,
453
        v lnwire.GossipVersion, nodePub route.Vertex) (*lnwire.FeatureVector,
454
        error) {
×
455

×
456
        return fetchNodeFeatures(ctx, s.db, v, nodePub)
×
457
}
×
458

459
// DisabledChannelIDs returns the channel ids of disabled channels.
460
// A channel is disabled when two of the associated ChanelEdgePolicies
461
// have their disabled bit on.
462
//
463
// NOTE: part of the Store interface.
464
func (s *SQLStore) DisabledChannelIDs(
465
        ctx context.Context, v lnwire.GossipVersion) ([]uint64, error) {
×
466

×
467
        var (
×
468
                chanIDs []uint64
×
469
        )
×
470
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
471
                var (
×
472
                        dbChanIDs [][]byte
×
473
                        err       error
×
474
                )
×
475
                switch v {
×
476
                case gossipV1:
×
477
                        dbChanIDs, err = db.GetV1DisabledSCIDs(ctx)
×
478
                case gossipV2:
×
479
                        dbChanIDs, err = db.GetV2DisabledSCIDs(ctx)
×
480
                default:
×
481
                        return fmt.Errorf("unsupported gossip version: %d", v)
×
482
                }
483
                if err != nil {
×
484
                        return fmt.Errorf("unable to fetch disabled "+
×
485
                                "channels: %w", err)
×
486
                }
×
487

488
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
489

×
490
                return nil
×
491
        }, sqldb.NoOpReset)
492
        if err != nil {
×
493
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
494
                        err)
×
495
        }
×
496

497
        return chanIDs, nil
×
498
}
499

500
// LookupAlias attempts to return the alias as advertised by the target node.
501
//
502
// NOTE: part of the Store interface.
503
func (s *SQLStore) LookupAlias(ctx context.Context, v lnwire.GossipVersion,
504
        pub *btcec.PublicKey) (string, error) {
×
505

×
506
        var alias string
×
507
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
508
                dbNode, err := db.GetNodeByPubKey(
×
509
                        ctx, sqlc.GetNodeByPubKeyParams{
×
510
                                Version: int16(v),
×
511
                                PubKey:  pub.SerializeCompressed(),
×
512
                        },
×
513
                )
×
514
                if errors.Is(err, sql.ErrNoRows) {
×
515
                        return ErrNodeAliasNotFound
×
516
                } else if err != nil {
×
517
                        return fmt.Errorf("unable to fetch node: %w", err)
×
518
                }
×
519

520
                if !dbNode.Alias.Valid {
×
521
                        return ErrNodeAliasNotFound
×
522
                }
×
523

524
                alias = dbNode.Alias.String
×
525

×
526
                return nil
×
527
        }, sqldb.NoOpReset)
528
        if err != nil {
×
529
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
530
        }
×
531

532
        return alias, nil
×
533
}
534

535
// SourceNode returns the source node of the graph. The source node is treated
536
// as the center node within a star-graph. This method may be used to kick off
537
// a path finding algorithm in order to explore the reachability of another
538
// node based off the source node.
539
//
540
// NOTE: part of the Store interface.
541
func (s *SQLStore) SourceNode(ctx context.Context,
542
        v lnwire.GossipVersion) (*models.Node, error) {
×
543

×
544
        var node *models.Node
×
545
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
546
                _, nodePub, err := s.getSourceNode(ctx, db, v)
×
547
                if err != nil {
×
548
                        return fmt.Errorf("unable to fetch source node: %w",
×
549
                                err)
×
550
                }
×
551

552
                _, node, err = getNodeByPubKey(
×
553
                        ctx, s.cfg.QueryCfg, db, v, nodePub,
×
554
                )
×
555

×
556
                return err
×
557
        }, sqldb.NoOpReset)
558
        if err != nil {
×
559
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
560
        }
×
561

562
        return node, nil
×
563
}
564

565
// SetSourceNode sets the source node within the graph database. The source
566
// node is to be used as the center of a star-graph within path finding
567
// algorithms.
568
//
569
// NOTE: part of the Store interface.
570
func (s *SQLStore) SetSourceNode(ctx context.Context,
571
        node *models.Node) error {
×
572

×
573
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
574
                // For the source node, we use a less strict upsert that allows
×
575
                // updates even when the timestamp hasn't changed. This handles
×
576
                // the race condition where multiple goroutines (e.g.,
×
577
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
578
                // same old timestamp, independently increment it, and try to
×
579
                // write concurrently. We want all parameter changes to persist,
×
580
                // even if timestamps collide.
×
581
                id, err := upsertSourceNode(ctx, db, node)
×
582
                if err != nil {
×
583
                        return fmt.Errorf("unable to upsert source node: %w",
×
584
                                err)
×
585
                }
×
586

587
                // Make sure that if a source node for this version is already
588
                // set, then the ID is the same as the one we are about to set.
589
                dbSourceNodeID, _, err := s.getSourceNode(
×
590
                        ctx, db, node.Version,
×
591
                )
×
592
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
593
                        return fmt.Errorf("unable to fetch source node: %w",
×
594
                                err)
×
595
                } else if err == nil {
×
596
                        if dbSourceNodeID != id {
×
597
                                return fmt.Errorf("v1 source node already "+
×
598
                                        "set to a different node: %d vs %d",
×
599
                                        dbSourceNodeID, id)
×
600
                        }
×
601

602
                        return nil
×
603
                }
604

605
                return db.AddSourceNode(ctx, id)
×
606
        }, sqldb.NoOpReset)
607
}
608

609
// NodeUpdatesInHorizon returns all the known lightning node which have an
610
// update timestamp within the passed range. This method can be used by two
611
// nodes to quickly determine if they have the same set of up to date node
612
// announcements.
613
//
614
// NOTE: This is part of the Store interface.
615
func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context,
616
        startTime, endTime time.Time,
617
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
618

×
619
        cfg := defaultIteratorConfig()
×
620
        for _, opt := range opts {
×
621
                opt(cfg)
×
622
        }
×
623

624
        return func(yield func(*models.Node, error) bool) {
×
625
                var (
×
626
                        lastUpdateTime sql.NullInt64
×
627
                        lastPubKey     = make([]byte, 33)
×
628
                        hasMore        = true
×
629
                )
×
630

×
631
                // Each iteration, we'll read a batch amount of nodes, yield
×
632
                // them, then decide is we have more or not.
×
633
                for hasMore {
×
634
                        var batch []*models.Node
×
635

×
636
                        //nolint:ll
×
637
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
638
                                //nolint:ll
×
639
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
640
                                        StartTime: sqldb.SQLInt64(
×
641
                                                startTime.Unix(),
×
642
                                        ),
×
643
                                        EndTime: sqldb.SQLInt64(
×
644
                                                endTime.Unix(),
×
645
                                        ),
×
646
                                        LastUpdate: lastUpdateTime,
×
647
                                        LastPubKey: lastPubKey,
×
648
                                        OnlyPublic: sql.NullBool{
×
649
                                                Bool:  cfg.iterPublicNodes,
×
650
                                                Valid: true,
×
651
                                        },
×
652
                                        MaxResults: sqldb.SQLInt32(
×
653
                                                cfg.nodeUpdateIterBatchSize,
×
654
                                        ),
×
655
                                }
×
656
                                rows, err := db.GetNodesByLastUpdateRange(
×
657
                                        ctx, params,
×
658
                                )
×
659
                                if err != nil {
×
660
                                        return err
×
661
                                }
×
662

663
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
664

×
665
                                err = forEachNodeInBatch(
×
666
                                        ctx, s.cfg.QueryCfg, db, rows,
×
667
                                        func(_ int64, node *models.Node) error {
×
668
                                                batch = append(batch, node)
×
669

×
670
                                                // Update pagination cursors
×
671
                                                // based on the last processed
×
672
                                                // node.
×
673
                                                lastUpdateTime = sql.NullInt64{
×
674
                                                        Int64: node.LastUpdate.
×
675
                                                                Unix(),
×
676
                                                        Valid: true,
×
677
                                                }
×
678
                                                lastPubKey = node.PubKeyBytes[:]
×
679

×
680
                                                return nil
×
681
                                        },
×
682
                                )
683
                                if err != nil {
×
684
                                        return fmt.Errorf("unable to build "+
×
685
                                                "nodes: %w", err)
×
686
                                }
×
687

688
                                return nil
×
689
                        }, func() {
×
690
                                batch = []*models.Node{}
×
691
                        })
×
692

693
                        if err != nil {
×
694
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
695
                                        "error: %v", err)
×
696

×
697
                                yield(&models.Node{}, err)
×
698

×
699
                                return
×
700
                        }
×
701

702
                        for _, node := range batch {
×
703
                                if !yield(node, nil) {
×
704
                                        return
×
705
                                }
×
706
                        }
707

708
                        // If the batch didn't yield anything, then we're done.
709
                        if len(batch) == 0 {
×
710
                                break
×
711
                        }
712
                }
713
        }
714
}
715

716
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
717
// undirected edge from the two target nodes are created. The information stored
718
// denotes the static attributes of the channel, such as the channelID, the keys
719
// involved in creation of the channel, and the set of features that the channel
720
// supports. The chanPoint and chanID are used to uniquely identify the edge
721
// globally within the database.
722
//
723
// NOTE: part of the Store interface.
724
func (s *SQLStore) AddChannelEdge(ctx context.Context,
725
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
726

×
727
        if !isKnownGossipVersion(edge.Version) {
×
728
                return fmt.Errorf("unsupported gossip version: %d",
×
729
                        edge.Version)
×
730
        }
×
731

732
        var alreadyExists bool
×
733
        r := &batch.Request[SQLQueries]{
×
734
                Opts: batch.NewSchedulerOptions(opts...),
×
735
                Reset: func() {
×
736
                        alreadyExists = false
×
737
                },
×
738
                Do: func(tx SQLQueries) error {
×
739
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
740

×
741
                        // Make sure that the channel doesn't already exist. We
×
742
                        // do this explicitly instead of relying on catching a
×
743
                        // unique constraint error because relying on SQL to
×
744
                        // throw that error would abort the entire batch of
×
745
                        // transactions.
×
746
                        _, err := tx.GetChannelBySCID(
×
747
                                ctx, sqlc.GetChannelBySCIDParams{
×
748
                                        Scid:    chanIDB,
×
749
                                        Version: int16(edge.Version),
×
750
                                },
×
751
                        )
×
752
                        if err == nil {
×
753
                                alreadyExists = true
×
754
                                return nil
×
755
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
756
                                return fmt.Errorf("unable to fetch channel: %w",
×
757
                                        err)
×
758
                        }
×
759

760
                        return insertChannel(ctx, tx, edge)
×
761
                },
762
                OnCommit: func(err error) error {
×
763
                        switch {
×
764
                        case err != nil:
×
765
                                return err
×
766
                        case alreadyExists:
×
767
                                return ErrEdgeAlreadyExist
×
768
                        default:
×
769
                                s.rejectCache.remove(
×
770
                                        edge.Version, edge.ChannelID,
×
771
                                )
×
772
                                s.chanCache.remove(
×
773
                                        edge.Version, edge.ChannelID,
×
774
                                )
×
775

×
776
                                return nil
×
777
                        }
778
                },
779
        }
780

781
        return s.chanScheduler.Execute(ctx, r)
×
782
}
783

784
// HighestChanID returns the "highest" known channel ID in the channel graph.
785
// This represents the "newest" channel from the PoV of the chain. This method
786
// can be used by peers to quickly determine if their graphs are in sync.
787
//
788
// NOTE: This is part of the Store interface.
789
func (s *SQLStore) HighestChanID(ctx context.Context,
790
        v lnwire.GossipVersion) (uint64, error) {
×
791

×
792
        var highestChanID uint64
×
793
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
794
                if !isKnownGossipVersion(v) {
×
795
                        return fmt.Errorf("unsupported gossip version: %d", v)
×
796
                }
×
797

798
                chanID, err := db.HighestSCID(ctx, int16(v))
×
799
                if errors.Is(err, sql.ErrNoRows) {
×
800
                        return nil
×
801
                } else if err != nil {
×
802
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
803
                                err)
×
804
                }
×
805

806
                highestChanID = byteOrder.Uint64(chanID)
×
807

×
808
                return nil
×
809
        }, sqldb.NoOpReset)
810
        if err != nil {
×
811
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
812
        }
×
813

814
        return highestChanID, nil
×
815
}
816

817
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
818
// within the database for the referenced channel. The `flags` attribute within
819
// the ChannelEdgePolicy determines which of the directed edges are being
820
// updated. If the flag is 1, then the first node's information is being
821
// updated, otherwise it's the second node's information. The node ordering is
822
// determined by the lexicographical ordering of the identity public keys of the
823
// nodes on either side of the channel.
824
//
825
// NOTE: part of the Store interface.
826
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
827
        edge *models.ChannelEdgePolicy,
828
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
829

×
830
        var (
×
831
                isUpdate1    bool
×
832
                edgeNotFound bool
×
833
                from, to     route.Vertex
×
834
        )
×
835

×
836
        r := &batch.Request[SQLQueries]{
×
837
                Opts: batch.NewSchedulerOptions(opts...),
×
838
                Reset: func() {
×
839
                        isUpdate1 = false
×
840
                        edgeNotFound = false
×
841
                },
×
842
                Do: func(tx SQLQueries) error {
×
843
                        var err error
×
844
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
845
                                ctx, tx, edge,
×
846
                        )
×
847
                        // It is possible that two of the same policy
×
848
                        // announcements are both being processed in the same
×
849
                        // batch. This may case the UpsertEdgePolicy conflict to
×
850
                        // be hit since we require at the db layer that the
×
851
                        // new last_update is greater than the existing
×
852
                        // last_update. We need to gracefully handle this here.
×
853
                        if errors.Is(err, sql.ErrNoRows) {
×
854
                                return nil
×
855
                        } else if err != nil {
×
856
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
857
                        }
×
858

859
                        // Silence ErrEdgeNotFound so that the batch can
860
                        // succeed, but propagate the error via local state.
861
                        if errors.Is(err, ErrEdgeNotFound) {
×
862
                                edgeNotFound = true
×
863
                                return nil
×
864
                        }
×
865

866
                        return err
×
867
                },
868
                OnCommit: func(err error) error {
×
869
                        switch {
×
870
                        case err != nil:
×
871
                                return err
×
872
                        case edgeNotFound:
×
873
                                return ErrEdgeNotFound
×
874
                        default:
×
875
                                s.updateEdgeCache(edge, isUpdate1)
×
876
                                return nil
×
877
                        }
878
                },
879
        }
880

881
        err := s.chanScheduler.Execute(ctx, r)
×
882

×
883
        return from, to, err
×
884
}
885

886
// updateEdgeCache updates our reject and channel caches with the new
887
// edge policy information.
888
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
889
        isUpdate1 bool) {
×
890

×
891
        // If an entry for this channel is found in reject cache, we'll modify
×
892
        // the entry with the updated timestamp for the direction that was just
×
893
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
894
        // during the next query for this edge.
×
895
        if entry, ok := s.rejectCache.get(e.Version, e.ChannelID); ok {
×
896
                switch e.Version {
×
897
                case gossipV1:
×
898
                        updateRejectCacheEntryV1(
×
899
                                &entry, isUpdate1, e.LastUpdate,
×
900
                        )
×
901
                case gossipV2:
×
902
                        updateRejectCacheEntryV2(
×
903
                                &entry, isUpdate1, e.LastBlockHeight,
×
904
                        )
×
905
                }
906
                s.rejectCache.insert(e.Version, e.ChannelID, entry)
×
907
        }
908

909
        // If an entry for this channel is found in channel cache, we'll modify
910
        // the entry with the updated policy for the direction that was just
911
        // written. If the edge doesn't exist, we'll defer loading the info and
912
        // policies and lazily read from disk during the next query.
913
        if channel, ok := s.chanCache.get(e.Version, e.ChannelID); ok {
×
914
                if isUpdate1 {
×
915
                        channel.Policy1 = e
×
916
                } else {
×
917
                        channel.Policy2 = e
×
918
                }
×
919
                s.chanCache.insert(e.Version, e.ChannelID, channel)
×
920
        }
921
}
922

923
// ForEachSourceNodeChannel iterates through all channels of the source node,
924
// executing the passed callback on each. The call-back is provided with the
925
// channel's outpoint, whether we have a policy for the channel and the channel
926
// peer's node information.
927
//
928
// NOTE: part of the Store interface.
929
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
930
        v lnwire.GossipVersion, cb func(chanPoint wire.OutPoint,
931
                havePolicy bool, otherNode *models.Node) error,
932
        reset func()) error {
×
933

×
934
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                nodeID, nodePub, err := s.getSourceNode(ctx, db, v)
×
936
                if err != nil {
×
937
                        return fmt.Errorf("unable to fetch source node: %w",
×
938
                                err)
×
939
                }
×
940

941
                return forEachNodeChannel(
×
942
                        ctx, db, s.cfg, v, nodeID,
×
943
                        func(info *models.ChannelEdgeInfo,
×
944
                                outPolicy *models.ChannelEdgePolicy,
×
945
                                _ *models.ChannelEdgePolicy) error {
×
946

×
947
                                // Fetch the other node.
×
948
                                var (
×
949
                                        otherNodePub [33]byte
×
950
                                        node1        = info.NodeKey1Bytes
×
951
                                        node2        = info.NodeKey2Bytes
×
952
                                )
×
953
                                switch {
×
954
                                case bytes.Equal(node1[:], nodePub[:]):
×
955
                                        otherNodePub = node2
×
956
                                case bytes.Equal(node2[:], nodePub[:]):
×
957
                                        otherNodePub = node1
×
958
                                default:
×
959
                                        return fmt.Errorf("node not " +
×
960
                                                "participating in this channel")
×
961
                                }
962

963
                                _, otherNode, err := getNodeByPubKey(
×
964
                                        ctx, s.cfg.QueryCfg, db, v,
×
965
                                        otherNodePub,
×
966
                                )
×
967
                                if err != nil {
×
968
                                        return fmt.Errorf("unable to fetch "+
×
969
                                                "other node(%x): %w",
×
970
                                                otherNodePub, err)
×
971
                                }
×
972

973
                                return cb(
×
974
                                        info.ChannelPoint, outPolicy != nil,
×
975
                                        otherNode,
×
976
                                )
×
977
                        },
978
                )
979
        }, reset)
980
}
981

982
// ForEachNode iterates through all the stored vertices/nodes in the graph,
983
// executing the passed callback with each node encountered. If the callback
984
// returns an error, then the transaction is aborted and the iteration stops
985
// early.
986
//
987
// NOTE: part of the Store interface.
988
func (s *SQLStore) ForEachNode(ctx context.Context, v lnwire.GossipVersion,
989
        cb func(node *models.Node) error, reset func()) error {
×
990

×
991
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
992
                return forEachNodePaginated(
×
993
                        ctx, s.cfg.QueryCfg, db,
×
NEW
994
                        v, func(_ context.Context, _ int64,
×
995
                                node *models.Node) error {
×
996

×
997
                                return cb(node)
×
998
                        },
×
999
                )
1000
        }, reset)
1001
}
1002

1003
// ForEachNodeDirectedChannel iterates through all channels of a given node,
1004
// executing the passed callback on the directed edge representing the channel
1005
// and its incoming policy. If the callback returns an error, then the iteration
1006
// is halted with the error propagated back up to the caller.
1007
//
1008
// Unknown policies are passed into the callback as nil values.
1009
//
1010
// NOTE: this is part of the graphdb.NodeTraverser interface.
1011
func (s *SQLStore) ForEachNodeDirectedChannel(ctx context.Context,
1012
        v lnwire.GossipVersion, nodePub route.Vertex,
1013
        cb func(channel *DirectedChannel) error, reset func()) error {
×
1014

×
1015
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1016
                return forEachNodeDirectedChannel(ctx, db, v, nodePub, cb)
×
1017
        }, reset)
×
1018
}
1019

1020
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
1021
// graph, executing the passed callback with each node encountered. If the
1022
// callback returns an error, then the transaction is aborted and the iteration
1023
// stops early.
1024
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
1025
        v lnwire.GossipVersion, cb func(route.Vertex,
1026
                *lnwire.FeatureVector) error, reset func()) error {
×
1027

×
1028
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1029
                return forEachNodeCacheable(
×
1030
                        ctx, s.cfg.QueryCfg, db, v,
×
1031
                        func(_ int64, nodePub route.Vertex,
×
1032
                                features *lnwire.FeatureVector) error {
×
1033

×
1034
                                return cb(nodePub, features)
×
1035
                        },
×
1036
                )
1037
        }, reset)
1038
        if err != nil {
×
1039
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
1040
        }
×
1041

1042
        return nil
×
1043
}
1044

1045
// ForEachNodeChannel iterates through all channels of the given node,
1046
// executing the passed callback with an edge info structure and the policies
1047
// of each end of the channel. The first edge policy is the outgoing edge *to*
1048
// the connecting node, while the second is the incoming edge *from* the
1049
// connecting node. If the callback returns an error, then the iteration is
1050
// halted with the error propagated back up to the caller.
1051
//
1052
// Unknown policies are passed into the callback as nil values.
1053
//
1054
// NOTE: part of the Store interface.
1055
func (s *SQLStore) ForEachNodeChannel(ctx context.Context,
1056
        v lnwire.GossipVersion, nodePub route.Vertex,
1057
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1058
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1059

×
1060
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1061
                dbNode, err := db.GetNodeByPubKey(
×
1062
                        ctx, sqlc.GetNodeByPubKeyParams{
×
1063
                                Version: int16(v),
×
1064
                                PubKey:  nodePub[:],
×
1065
                        },
×
1066
                )
×
1067
                if errors.Is(err, sql.ErrNoRows) {
×
1068
                        return nil
×
1069
                } else if err != nil {
×
1070
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1071
                }
×
1072

1073
                return forEachNodeChannel(ctx, db, s.cfg, v, dbNode.ID, cb)
×
1074
        }, reset)
1075
}
1076

1077
// extractMaxUpdateTime returns the maximum of the two policy update times.
1078
// This is used for pagination cursor tracking.
1079
func extractMaxUpdateTime(
1080
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1081

×
1082
        switch {
×
1083
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1084
                return max(row.Policy1LastUpdate.Int64,
×
1085
                        row.Policy2LastUpdate.Int64)
×
1086
        case row.Policy1LastUpdate.Valid:
×
1087
                return row.Policy1LastUpdate.Int64
×
1088
        case row.Policy2LastUpdate.Valid:
×
1089
                return row.Policy2LastUpdate.Int64
×
1090
        default:
×
1091
                return 0
×
1092
        }
1093
}
1094

1095
// buildChannelFromRow constructs a ChannelEdge from a database row.
1096
// This includes building the nodes, channel info, and policies.
1097
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1098
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1099

×
1100
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1101
        if err != nil {
×
1102
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1103
                        err)
×
1104
        }
×
1105

1106
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1107
        if err != nil {
×
1108
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1109
                        err)
×
1110
        }
×
1111

1112
        channel, err := getAndBuildEdgeInfo(
×
1113
                ctx, s.cfg, db,
×
1114
                row.GraphChannel, node1.PubKeyBytes,
×
1115
                node2.PubKeyBytes,
×
1116
        )
×
1117
        if err != nil {
×
1118
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1119
                        "channel info: %w", err)
×
1120
        }
×
1121

1122
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1123
        if err != nil {
×
1124
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1125
                        "channel policies: %w", err)
×
1126
        }
×
1127

1128
        p1, p2, err := getAndBuildChanPolicies(
×
1129
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1130
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1131
        )
×
1132
        if err != nil {
×
1133
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1134
                        "channel policies: %w", err)
×
1135
        }
×
1136

1137
        return ChannelEdge{
×
1138
                Info:    channel,
×
1139
                Policy1: p1,
×
1140
                Policy2: p2,
×
1141
                Node1:   node1,
×
1142
                Node2:   node2,
×
1143
        }, nil
×
1144
}
1145

1146
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1147
// This method acquires the cache lock only once for the entire batch.
1148
func (s *SQLStore) updateChanCacheBatch(v lnwire.GossipVersion,
1149
        edgesToCache map[uint64]ChannelEdge) {
×
1150

×
1151
        if len(edgesToCache) == 0 {
×
1152
                return
×
1153
        }
×
1154

1155
        s.cacheMu.Lock()
×
1156
        defer s.cacheMu.Unlock()
×
1157

×
1158
        for chanID, edge := range edgesToCache {
×
1159
                s.chanCache.insert(v, chanID, edge)
×
1160
        }
×
1161
}
1162

1163
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1164
// one edge that has an update timestamp within the specified horizon.
1165
//
1166
// Iterator Lifecycle:
1167
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1168
// 2. Query batch of channels with policies in time range
1169
// 3. For each channel: check if seen, check cache, or build from DB
1170
// 4. Yield channels to caller
1171
// 5. Update cache after successful batch
1172
// 6. Repeat with updated pagination cursor until no more results
1173
//
1174
// NOTE: This is part of the Store interface.
1175
func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context,
1176
        startTime, endTime time.Time,
1177
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1178

×
1179
        // Apply options.
×
1180
        cfg := defaultIteratorConfig()
×
1181
        for _, opt := range opts {
×
1182
                opt(cfg)
×
1183
        }
×
1184

1185
        return func(yield func(ChannelEdge, error) bool) {
×
1186
                var (
×
1187
                        edgesSeen      = make(map[uint64]struct{})
×
1188
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1189
                        hits           int
×
1190
                        total          int
×
1191
                        lastUpdateTime sql.NullInt64
×
1192
                        lastID         sql.NullInt64
×
1193
                        hasMore        = true
×
1194
                )
×
1195

×
1196
                // Each iteration, we'll read a batch amount of channel updates
×
1197
                // (consulting the cache along the way), yield them, then loop
×
1198
                // back to decide if we have any more updates to read out.
×
1199
                for hasMore {
×
1200
                        var batch []ChannelEdge
×
1201

×
1202
                        // Acquire read lock before starting transaction to
×
1203
                        // ensure consistent lock ordering (cacheMu -> DB) and
×
1204
                        // prevent deadlock with write operations.
×
1205
                        s.cacheMu.RLock()
×
1206

×
1207
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1208
                                func(db SQLQueries) error {
×
1209
                                        //nolint:ll
×
1210
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1211
                                                Version: int16(lnwire.GossipVersion1),
×
1212
                                                StartTime: sqldb.SQLInt64(
×
1213
                                                        startTime.Unix(),
×
1214
                                                ),
×
1215
                                                EndTime: sqldb.SQLInt64(
×
1216
                                                        endTime.Unix(),
×
1217
                                                ),
×
1218
                                                LastUpdateTime: lastUpdateTime,
×
1219
                                                LastID:         lastID,
×
1220
                                                MaxResults: sql.NullInt32{
×
1221
                                                        Int32: int32(
×
1222
                                                                cfg.chanUpdateIterBatchSize,
×
1223
                                                        ),
×
1224
                                                        Valid: true,
×
1225
                                                },
×
1226
                                        }
×
1227
                                        //nolint:ll
×
1228
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1229
                                                ctx, params,
×
1230
                                        )
×
1231
                                        if err != nil {
×
1232
                                                return err
×
1233
                                        }
×
1234

1235
                                        //nolint:ll
1236
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1237

×
1238
                                        //nolint:ll
×
1239
                                        for _, row := range rows {
×
1240
                                                lastUpdateTime = sql.NullInt64{
×
1241
                                                        Int64: extractMaxUpdateTime(row),
×
1242
                                                        Valid: true,
×
1243
                                                }
×
1244
                                                lastID = sql.NullInt64{
×
1245
                                                        Int64: row.GraphChannel.ID,
×
1246
                                                        Valid: true,
×
1247
                                                }
×
1248

×
1249
                                                // Skip if we've already
×
1250
                                                // processed this channel.
×
1251
                                                chanIDInt := byteOrder.Uint64(
×
1252
                                                        row.GraphChannel.Scid,
×
1253
                                                )
×
1254
                                                _, ok := edgesSeen[chanIDInt]
×
1255
                                                if ok {
×
1256
                                                        continue
×
1257
                                                }
1258

1259
                                                // Check cache (we already hold
1260
                                                // shared read lock).
1261
                                                channel, ok := s.chanCache.get(
×
1262
                                                        lnwire.GossipVersion1,
×
1263
                                                        chanIDInt,
×
1264
                                                )
×
1265
                                                if ok {
×
1266
                                                        hits++
×
1267
                                                        total++
×
1268
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1269
                                                        batch = append(batch, channel)
×
1270

×
1271
                                                        continue
×
1272
                                                }
1273

1274
                                                chanEdge, err := s.buildChannelFromRow(
×
1275
                                                        ctx, db, row,
×
1276
                                                )
×
1277
                                                if err != nil {
×
1278
                                                        return err
×
1279
                                                }
×
1280

1281
                                                edgesSeen[chanIDInt] = struct{}{}
×
1282
                                                edgesToCache[chanIDInt] = chanEdge
×
1283

×
1284
                                                batch = append(batch, chanEdge)
×
1285

×
1286
                                                total++
×
1287
                                        }
1288

1289
                                        return nil
×
1290
                                }, func() {
×
1291
                                        batch = nil
×
1292
                                        edgesSeen = make(map[uint64]struct{})
×
1293
                                        edgesToCache = make(
×
1294
                                                map[uint64]ChannelEdge,
×
1295
                                        )
×
1296
                                })
×
1297

1298
                        // Release read lock after transaction completes.
1299
                        s.cacheMu.RUnlock()
×
1300

×
1301
                        if err != nil {
×
1302
                                log.Errorf("ChanUpdatesInHorizon "+
×
1303
                                        "batch error: %v", err)
×
1304

×
1305
                                yield(ChannelEdge{}, err)
×
1306

×
1307
                                return
×
1308
                        }
×
1309

1310
                        for _, edge := range batch {
×
1311
                                if !yield(edge, nil) {
×
1312
                                        return
×
1313
                                }
×
1314
                        }
1315

1316
                        // Update cache after successful batch yield, setting
1317
                        // the cache lock only once for the entire batch.
1318
                        s.updateChanCacheBatch(
×
1319
                                lnwire.GossipVersion1, edgesToCache,
×
1320
                        )
×
1321
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1322

×
1323
                        // If the batch didn't yield anything, then we're done.
×
1324
                        if len(batch) == 0 {
×
1325
                                break
×
1326
                        }
1327
                }
1328

1329
                if total > 0 {
×
1330
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1331
                                "%.2f (%d/%d)",
×
1332
                                float64(hits)*100/float64(total), hits, total)
×
1333
                } else {
×
1334
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1335
                                "in horizon (%s, %s)", startTime, endTime)
×
1336
                }
×
1337
        }
1338
}
1339

1340
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1341
// data to the call-back. If withAddrs is true, then the call-back will also be
1342
// provided with the addresses associated with the node. The address retrieval
1343
// result in an additional round-trip to the database, so it should only be used
1344
// if the addresses are actually needed.
1345
//
1346
// NOTE: part of the Store interface.
1347
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1348
        v lnwire.GossipVersion, withAddrs bool,
1349
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1350
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1351

×
1352
        type nodeCachedBatchData struct {
×
1353
                features      map[int64][]int
×
1354
                addrs         map[int64][]nodeAddress
×
1355
                chanBatchData *batchChannelData
×
1356
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1357
        }
×
1358

×
1359
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1360
                // pageQueryFunc is used to query the next page of nodes.
×
1361
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1362
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1363

×
1364
                        return db.ListNodeIDsAndPubKeys(
×
1365
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
NEW
1366
                                        Version: int16(v),
×
1367
                                        ID:      lastID,
×
1368
                                        Limit:   limit,
×
1369
                                },
×
1370
                        )
×
1371
                }
×
1372

1373
                // batchDataFunc is then used to batch load the data required
1374
                // for each page of nodes.
1375
                batchDataFunc := func(ctx context.Context,
×
1376
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1377

×
1378
                        // Batch load node features.
×
1379
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1380
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1381
                        )
×
1382
                        if err != nil {
×
1383
                                return nil, fmt.Errorf("unable to batch load "+
×
1384
                                        "node features: %w", err)
×
1385
                        }
×
1386

1387
                        // Maybe fetch the node's addresses if requested.
1388
                        var nodeAddrs map[int64][]nodeAddress
×
1389
                        if withAddrs {
×
1390
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1391
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1392
                                )
×
1393
                                if err != nil {
×
1394
                                        return nil, fmt.Errorf("unable to "+
×
1395
                                                "batch load node "+
×
1396
                                                "addresses: %w", err)
×
1397
                                }
×
1398
                        }
1399

1400
                        // Batch load ALL unique channels for ALL nodes in this
1401
                        // page.
1402
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1403
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1404
                                        Version:  int16(lnwire.GossipVersion1),
×
1405
                                        Node1Ids: nodeIDs,
×
1406
                                        Node2Ids: nodeIDs,
×
1407
                                },
×
1408
                        )
×
1409
                        if err != nil {
×
1410
                                return nil, fmt.Errorf("unable to batch "+
×
1411
                                        "fetch channels for nodes: %w", err)
×
1412
                        }
×
1413

1414
                        // Deduplicate channels and collect IDs.
1415
                        var (
×
1416
                                allChannelIDs []int64
×
1417
                                allPolicyIDs  []int64
×
1418
                        )
×
1419
                        uniqueChannels := make(
×
1420
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1421
                        )
×
1422

×
1423
                        for _, channel := range allChannels {
×
1424
                                channelID := channel.GraphChannel.ID
×
1425

×
1426
                                // Only process each unique channel once.
×
1427
                                _, exists := uniqueChannels[channelID]
×
1428
                                if exists {
×
1429
                                        continue
×
1430
                                }
1431

1432
                                uniqueChannels[channelID] = channel
×
1433
                                allChannelIDs = append(allChannelIDs, channelID)
×
1434

×
1435
                                if channel.Policy1ID.Valid {
×
1436
                                        allPolicyIDs = append(
×
1437
                                                allPolicyIDs,
×
1438
                                                channel.Policy1ID.Int64,
×
1439
                                        )
×
1440
                                }
×
1441
                                if channel.Policy2ID.Valid {
×
1442
                                        allPolicyIDs = append(
×
1443
                                                allPolicyIDs,
×
1444
                                                channel.Policy2ID.Int64,
×
1445
                                        )
×
1446
                                }
×
1447
                        }
1448

1449
                        // Batch load channel data for all unique channels.
1450
                        channelBatchData, err := batchLoadChannelData(
×
1451
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1452
                                allPolicyIDs,
×
1453
                        )
×
1454
                        if err != nil {
×
1455
                                return nil, fmt.Errorf("unable to batch "+
×
1456
                                        "load channel data: %w", err)
×
1457
                        }
×
1458

1459
                        // Create map of node ID to channels that involve this
1460
                        // node.
1461
                        nodeIDSet := make(map[int64]bool)
×
1462
                        for _, nodeID := range nodeIDs {
×
1463
                                nodeIDSet[nodeID] = true
×
1464
                        }
×
1465

1466
                        nodeChannelMap := make(
×
1467
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1468
                        )
×
1469
                        for _, channel := range uniqueChannels {
×
1470
                                // Add channel to both nodes if they're in our
×
1471
                                // current page.
×
1472
                                node1 := channel.GraphChannel.NodeID1
×
1473
                                if nodeIDSet[node1] {
×
1474
                                        nodeChannelMap[node1] = append(
×
1475
                                                nodeChannelMap[node1], channel,
×
1476
                                        )
×
1477
                                }
×
1478
                                node2 := channel.GraphChannel.NodeID2
×
1479
                                if nodeIDSet[node2] {
×
1480
                                        nodeChannelMap[node2] = append(
×
1481
                                                nodeChannelMap[node2], channel,
×
1482
                                        )
×
1483
                                }
×
1484
                        }
1485

1486
                        return &nodeCachedBatchData{
×
1487
                                features:      nodeFeatures,
×
1488
                                addrs:         nodeAddrs,
×
1489
                                chanBatchData: channelBatchData,
×
1490
                                chanMap:       nodeChannelMap,
×
1491
                        }, nil
×
1492
                }
1493

1494
                // processItem is used to process each node in the current page.
1495
                processItem := func(ctx context.Context,
×
1496
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1497
                        batchData *nodeCachedBatchData) error {
×
1498

×
1499
                        // Build feature vector for this node.
×
1500
                        fv := lnwire.EmptyFeatureVector()
×
1501
                        features, exists := batchData.features[nodeData.ID]
×
1502
                        if exists {
×
1503
                                for _, bit := range features {
×
1504
                                        fv.Set(lnwire.FeatureBit(bit))
×
1505
                                }
×
1506
                        }
1507

1508
                        var nodePub route.Vertex
×
1509
                        copy(nodePub[:], nodeData.PubKey)
×
1510

×
1511
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1512

×
1513
                        toNodeCallback := func() route.Vertex {
×
1514
                                return nodePub
×
1515
                        }
×
1516

1517
                        // Build cached channels map for this node.
1518
                        channels := make(map[uint64]*DirectedChannel)
×
1519
                        for _, channelRow := range nodeChannels {
×
1520
                                directedChan, err := buildDirectedChannel(
×
1521
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1522
                                        channelRow, batchData.chanBatchData, fv,
×
1523
                                        toNodeCallback,
×
1524
                                )
×
1525
                                if err != nil {
×
1526
                                        return err
×
1527
                                }
×
1528

1529
                                channels[directedChan.ChannelID] = directedChan
×
1530
                        }
1531

1532
                        addrs, err := buildNodeAddresses(
×
1533
                                batchData.addrs[nodeData.ID],
×
1534
                        )
×
1535
                        if err != nil {
×
1536
                                return fmt.Errorf("unable to build node "+
×
1537
                                        "addresses: %w", err)
×
1538
                        }
×
1539

1540
                        return cb(ctx, nodePub, addrs, channels)
×
1541
                }
1542

1543
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1544
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1545
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1546
                                return node.ID
×
1547
                        },
×
1548
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1549
                                error) {
×
1550

×
1551
                                return node.ID, nil
×
1552
                        },
×
1553
                        batchDataFunc, processItem,
1554
                )
1555
        }, reset)
1556
}
1557

1558
// ForEachChannelCacheable iterates through all the channel edges stored
1559
// within the graph and invokes the passed callback for each edge. The
1560
// callback takes two edges as since this is a directed graph, both the
1561
// in/out edges are visited. If the callback returns an error, then the
1562
// transaction is aborted and the iteration stops early.
1563
//
1564
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1565
// pointer for that particular channel edge routing policy will be
1566
// passed into the callback.
1567
//
1568
// NOTE: this method is like ForEachChannel but fetches only the data
1569
// required for the graph cache.
1570
func (s *SQLStore) ForEachChannelCacheable(ctx context.Context,
1571
        v lnwire.GossipVersion,
1572
        cb func(*models.CachedEdgeInfo, *models.CachedEdgePolicy,
1573
                *models.CachedEdgePolicy) error, reset func()) error {
×
1574

×
1575
        if !isKnownGossipVersion(v) {
×
1576
                return fmt.Errorf("unsupported gossip version: %d", v)
×
1577
        }
×
1578

1579
        handleChannel := func(_ context.Context,
×
1580
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1581

×
1582
                node1, node2, err := buildNodeVertices(
×
1583
                        row.Node1Pubkey, row.Node2Pubkey,
×
1584
                )
×
1585
                if err != nil {
×
1586
                        return err
×
1587
                }
×
1588

1589
                edge := buildCacheableChannelInfo(
×
1590
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1591
                )
×
1592

×
1593
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1594
                if err != nil {
×
1595
                        return err
×
1596
                }
×
1597

1598
                pol1, pol2, err := buildCachedChanPolicies(
×
1599
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1600
                )
×
1601
                if err != nil {
×
1602
                        return err
×
1603
                }
×
1604

1605
                return cb(edge, pol1, pol2)
×
1606
        }
1607

1608
        extractCursor := func(
×
1609
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1610

×
1611
                return row.ID
×
1612
        }
×
1613

1614
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1615
                //nolint:ll
×
1616
                queryFunc := func(ctx context.Context, lastID int64,
×
1617
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1618
                        error) {
×
1619

×
1620
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1621
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1622
                                        Version: int16(v),
×
1623
                                        ID:      lastID,
×
1624
                                        Limit:   limit,
×
1625
                                },
×
1626
                        )
×
1627
                }
×
1628

1629
                return sqldb.ExecutePaginatedQuery(
×
1630
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1631
                        extractCursor, handleChannel,
×
1632
                )
×
1633
        }, reset)
1634
}
1635

1636
// ForEachChannel iterates through all the channel edges stored within the
1637
// graph and invokes the passed callback for each edge. The callback takes two
1638
// edges as since this is a directed graph, both the in/out edges are visited.
1639
// If the callback returns an error, then the transaction is aborted and the
1640
// iteration stops early.
1641
//
1642
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1643
// for that particular channel edge routing policy will be passed into the
1644
// callback.
1645
//
1646
// NOTE: part of the Store interface.
1647
func (s *SQLStore) ForEachChannel(ctx context.Context,
1648
        v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo,
1649
                *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error,
1650
        reset func()) error {
×
1651

×
1652
        if !isKnownGossipVersion(v) {
×
1653
                return fmt.Errorf("unsupported gossip version: %d", v)
×
1654
        }
×
1655

1656
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1657
                return forEachChannelWithPolicies(ctx, db, s.cfg, v, cb)
×
1658
        }, reset)
×
1659
}
1660

1661
// FilterChannelRange returns the channel ID's of all known channels which were
1662
// mined in a block height within the passed range. The channel IDs are grouped
1663
// by their common block height. This method can be used to quickly share with a
1664
// peer the set of channels we know of within a particular range to catch them
1665
// up after a period of time offline. If withTimestamps is true then the
1666
// timestamp info of the latest received channel update messages of the channel
1667
// will be included in the response.
1668
//
1669
// NOTE: This is part of the Store interface.
1670
func (s *SQLStore) FilterChannelRange(ctx context.Context,
1671
        v lnwire.GossipVersion, startHeight, endHeight uint32,
NEW
1672
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1673

×
1674
        var (
×
1675
                startSCID = &lnwire.ShortChannelID{
×
1676
                        BlockHeight: startHeight,
×
1677
                }
×
1678
                endSCID = lnwire.ShortChannelID{
×
1679
                        BlockHeight: endHeight,
×
1680
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1681
                        TxPosition:  math.MaxUint16,
×
1682
                }
×
1683
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1684
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1685
        )
×
1686

×
1687
        // 1) get all channels where channelID is between start and end chan ID.
×
1688
        // 2) skip if not public (ie, no channel_proof)
×
1689
        // 3) collect that channel.
×
1690
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1691
        //    and add those timestamps to the collected channel.
×
1692
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1693
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1694
                var (
×
NEW
1695
                        dbChans []sqlc.GraphChannel
×
NEW
1696
                        chanErr error
×
1697
                )
×
NEW
1698

×
NEW
1699
                switch v {
×
NEW
1700
                case gossipV1:
×
NEW
1701
                        dbChans, chanErr = db.GetPublicV1ChannelsBySCID(
×
NEW
1702
                                ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1703
                                        StartScid: chanIDStart,
×
NEW
1704
                                        EndScid:   chanIDEnd,
×
NEW
1705
                                },
×
NEW
1706
                        )
×
NEW
1707
                case gossipV2:
×
NEW
1708
                        dbChans, chanErr = db.GetPublicV2ChannelsBySCID(
×
NEW
1709
                                ctx, sqlc.GetPublicV2ChannelsBySCIDParams{
×
NEW
1710
                                        StartScid: chanIDStart,
×
NEW
1711
                                        EndScid:   chanIDEnd,
×
NEW
1712
                                },
×
NEW
1713
                        )
×
NEW
1714
                default:
×
NEW
1715
                        return fmt.Errorf("unsupported gossip version: %d", v)
×
1716
                }
NEW
1717
                if chanErr != nil {
×
1718
                        return fmt.Errorf("unable to fetch channel range: %w",
×
NEW
1719
                                chanErr)
×
1720
                }
×
1721

1722
                for _, dbChan := range dbChans {
×
1723
                        cid := lnwire.NewShortChanIDFromInt(
×
1724
                                byteOrder.Uint64(dbChan.Scid),
×
1725
                        )
×
NEW
1726

×
NEW
1727
                        var chanInfo ChannelUpdateInfo
×
NEW
1728
                        switch v {
×
NEW
1729
                        case gossipV1:
×
NEW
1730
                                chanInfo = NewV1ChannelUpdateInfo(
×
NEW
1731
                                        cid, time.Time{}, time.Time{},
×
NEW
1732
                                )
×
NEW
1733
                        case gossipV2:
×
NEW
1734
                                chanInfo = NewV2ChannelUpdateInfo(cid, 0, 0)
×
1735
                        }
1736

1737
                        if !withTimestamps {
×
1738
                                channelsPerBlock[cid.BlockHeight] = append(
×
1739
                                        channelsPerBlock[cid.BlockHeight],
×
1740
                                        chanInfo,
×
1741
                                )
×
1742

×
1743
                                continue
×
1744
                        }
1745

1746
                        //nolint:ll
1747
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1748
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1749
                                        Version:   int16(v),
×
1750
                                        ChannelID: dbChan.ID,
×
1751
                                        NodeID:    dbChan.NodeID1,
×
1752
                                },
×
1753
                        )
×
1754
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1755
                                return fmt.Errorf("unable to fetch node1 "+
×
1756
                                        "policy: %w", err)
×
1757
                        } else if err == nil {
×
NEW
1758
                                n1Update := node1Policy.LastUpdate.Int64
×
NEW
1759
                                n1Height := node1Policy.BlockHeight.Int64
×
NEW
1760

×
NEW
1761
                                switch v {
×
NEW
1762
                                case gossipV1:
×
NEW
1763
                                        chanInfo.Node1Freshness =
×
NEW
1764
                                                lnwire.UnixTimestamp(n1Update)
×
NEW
1765
                                case gossipV2:
×
NEW
1766
                                        chanInfo.Node1Freshness =
×
NEW
1767
                                                lnwire.BlockHeightTimestamp(
×
NEW
1768
                                                        n1Height,
×
NEW
1769
                                                )
×
1770
                                }
1771
                        }
1772

1773
                        //nolint:ll
1774
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1775
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1776
                                        Version:   int16(v),
×
1777
                                        ChannelID: dbChan.ID,
×
1778
                                        NodeID:    dbChan.NodeID2,
×
1779
                                },
×
1780
                        )
×
1781
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1782
                                return fmt.Errorf("unable to fetch node2 "+
×
1783
                                        "policy: %w", err)
×
1784
                        } else if err == nil {
×
NEW
1785
                                n2Update := node2Policy.LastUpdate.Int64
×
NEW
1786
                                n2Height := node2Policy.BlockHeight.Int64
×
NEW
1787

×
NEW
1788
                                switch v {
×
NEW
1789
                                case gossipV1:
×
NEW
1790
                                        chanInfo.Node2Freshness =
×
NEW
1791
                                                lnwire.UnixTimestamp(n2Update)
×
NEW
1792
                                case gossipV2:
×
NEW
1793
                                        chanInfo.Node2Freshness =
×
NEW
1794
                                                lnwire.BlockHeightTimestamp(
×
NEW
1795
                                                        n2Height,
×
NEW
1796
                                                )
×
1797
                                }
1798
                        }
1799

1800
                        channelsPerBlock[cid.BlockHeight] = append(
×
1801
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1802
                        )
×
1803
                }
1804

1805
                return nil
×
1806
        }, func() {
×
1807
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1808
        })
×
1809
        if err != nil {
×
1810
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1811
        }
×
1812

1813
        if len(channelsPerBlock) == 0 {
×
1814
                return nil, nil
×
1815
        }
×
1816

1817
        // Return the channel ranges in ascending block height order.
1818
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1819
        slices.Sort(blocks)
×
1820

×
1821
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1822
                return BlockChannelRange{
×
1823
                        Height:   block,
×
1824
                        Channels: channelsPerBlock[block],
×
1825
                }
×
1826
        }), nil
×
1827
}
1828

1829
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1830
// zombie for the given gossip version. This method is used on an ad-hoc basis,
1831
// when channels need to be marked as zombies outside the normal pruning cycle.
1832
//
1833
// NOTE: part of the Store interface.
1834
func (s *SQLStore) MarkEdgeZombie(ctx context.Context, v lnwire.GossipVersion,
NEW
1835
        chanID uint64, pubKey1, pubKey2 [33]byte) error {
×
NEW
1836

×
NEW
1837
        if !isKnownGossipVersion(v) {
×
NEW
1838
                return fmt.Errorf("unsupported gossip version: %d", v)
×
NEW
1839
        }
×
1840

1841
        s.cacheMu.Lock()
×
1842
        defer s.cacheMu.Unlock()
×
1843

×
1844
        chanIDB := channelIDToBytes(chanID)
×
1845

×
1846
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1847
                return db.UpsertZombieChannel(
×
1848
                        ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1849
                                Version:  int16(v),
×
1850
                                Scid:     chanIDB,
×
1851
                                NodeKey1: pubKey1[:],
×
1852
                                NodeKey2: pubKey2[:],
×
1853
                        },
×
1854
                )
×
1855
        }, sqldb.NoOpReset)
×
1856
        if err != nil {
×
1857
                return fmt.Errorf("unable to upsert zombie channel "+
×
1858
                        "(channel_id=%d): %w", chanID, err)
×
1859
        }
×
1860

NEW
1861
        s.rejectCache.remove(v, chanID)
×
NEW
1862
        s.chanCache.remove(v, chanID)
×
1863

×
1864
        return nil
×
1865
}
1866

1867
// MarkEdgeLive clears an edge from our zombie index for the given gossip
1868
// version, deeming it as live.
1869
//
1870
// NOTE: part of the Store interface.
1871
func (s *SQLStore) MarkEdgeLive(ctx context.Context,
NEW
1872
        v lnwire.GossipVersion, chanID uint64) error {
×
NEW
1873

×
1874
        s.cacheMu.Lock()
×
1875
        defer s.cacheMu.Unlock()
×
1876

×
NEW
1877
        if !isKnownGossipVersion(v) {
×
NEW
1878
                return fmt.Errorf("unsupported gossip version: %d", v)
×
NEW
1879
        }
×
1880

NEW
1881
        chanIDB := channelIDToBytes(chanID)
×
1882

×
1883
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1884
                res, err := db.DeleteZombieChannel(
×
1885
                        ctx, sqlc.DeleteZombieChannelParams{
×
1886
                                Scid:    chanIDB,
×
NEW
1887
                                Version: int16(v),
×
1888
                        },
×
1889
                )
×
1890
                if err != nil {
×
1891
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1892
                                err)
×
1893
                }
×
1894

1895
                rows, err := res.RowsAffected()
×
1896
                if err != nil {
×
1897
                        return err
×
1898
                }
×
1899

1900
                if rows == 0 {
×
1901
                        return ErrZombieEdgeNotFound
×
1902
                } else if rows > 1 {
×
1903
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1904
                                "expected 1", rows)
×
1905
                }
×
1906

1907
                return nil
×
1908
        }, sqldb.NoOpReset)
1909
        if err != nil {
×
1910
                return fmt.Errorf("unable to mark edge live "+
×
1911
                        "(channel_id=%d): %w", chanID, err)
×
1912
        }
×
1913

NEW
1914
        s.rejectCache.remove(v, chanID)
×
NEW
1915
        s.chanCache.remove(v, chanID)
×
1916

×
1917
        return err
×
1918
}
1919

1920
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1921
// zombie, then the two node public keys corresponding to this edge are also
1922
// returned.
1923
//
1924
// NOTE: part of the Store interface.
1925
func (s *SQLStore) IsZombieEdge(ctx context.Context, v lnwire.GossipVersion,
1926
        chanID uint64) (bool, [33]byte, [33]byte, error) {
×
1927

×
1928
        var (
×
1929
                isZombie         bool
×
1930
                pubKey1, pubKey2 route.Vertex
×
1931
                chanIDB          = channelIDToBytes(chanID)
×
1932
        )
×
1933

×
1934
        if !isKnownGossipVersion(v) {
×
1935
                return false, [33]byte{}, [33]byte{},
×
1936
                        fmt.Errorf("unsupported gossip version: %d", v)
×
1937
        }
×
1938

1939
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1940
                zombie, err := db.GetZombieChannel(
×
1941
                        ctx, sqlc.GetZombieChannelParams{
×
1942
                                Scid:    chanIDB,
×
1943
                                Version: int16(v),
×
1944
                        },
×
1945
                )
×
1946
                if errors.Is(err, sql.ErrNoRows) {
×
1947
                        return nil
×
1948
                }
×
1949
                if err != nil {
×
1950
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1951
                                err)
×
1952
                }
×
1953

1954
                copy(pubKey1[:], zombie.NodeKey1)
×
1955
                copy(pubKey2[:], zombie.NodeKey2)
×
1956
                isZombie = true
×
1957

×
1958
                return nil
×
1959
        }, sqldb.NoOpReset)
1960
        if err != nil {
×
1961
                return false, route.Vertex{}, route.Vertex{},
×
1962
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1963
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1964
        }
×
1965

1966
        return isZombie, pubKey1, pubKey2, nil
×
1967
}
1968

1969
// NumZombies returns the current number of zombie channels in the graph.
1970
//
1971
// NOTE: part of the Store interface.
1972
func (s *SQLStore) NumZombies(
1973
        ctx context.Context, v lnwire.GossipVersion,
NEW
1974
) (uint64, error) {
×
NEW
1975

×
1976
        var (
×
1977
                numZombies uint64
×
1978
        )
×
1979
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1980
                count, err := db.CountZombieChannels(
×
NEW
1981
                        ctx, int16(v),
×
1982
                )
×
1983
                if err != nil {
×
1984
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1985
                                err)
×
1986
                }
×
1987

1988
                numZombies = uint64(count)
×
1989

×
1990
                return nil
×
1991
        }, sqldb.NoOpReset)
1992
        if err != nil {
×
1993
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1994
        }
×
1995

1996
        return numZombies, nil
×
1997
}
1998

1999
// DeleteChannelEdges removes edges with the given channel IDs from the
2000
// database and marks them as zombies. This ensures that we're unable to re-add
2001
// it to our database once again. If an edge does not exist within the
2002
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
2003
// true, then when we mark these edges as zombies, we'll set up the keys such
2004
// that we require the node that failed to send the fresh update to be the one
2005
// that resurrects the channel from its zombie state. The markZombie bool
2006
// denotes whether to mark the channel as a zombie.
2007
//
2008
// NOTE: part of the Store interface.
2009
func (s *SQLStore) DeleteChannelEdges(ctx context.Context,
2010
        v lnwire.GossipVersion, strictZombiePruning, markZombie bool,
2011
        chanIDs ...uint64) (
2012
        []*models.ChannelEdgeInfo, error) {
×
2013

×
2014
        s.cacheMu.Lock()
×
2015
        defer s.cacheMu.Unlock()
×
2016

×
2017
        // Keep track of which channels we end up finding so that we can
×
2018
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
2019
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
2020
        for _, chanID := range chanIDs {
×
2021
                chanLookup[chanID] = struct{}{}
×
2022
        }
×
2023

2024
        var edges []*models.ChannelEdgeInfo
×
2025
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2026
                // First, collect all channel rows.
×
2027
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2028
                chanCallBack := func(ctx context.Context,
×
2029
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2030

×
2031
                        // Deleting the entry from the map indicates that we
×
2032
                        // have found the channel.
×
2033
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
2034
                        delete(chanLookup, scid)
×
2035

×
2036
                        channelRows = append(channelRows, row)
×
2037

×
2038
                        return nil
×
2039
                }
×
2040

2041
                err := s.forEachChanWithPoliciesInSCIDList(
×
2042
                        ctx, db, v, chanCallBack, chanIDs,
×
2043
                )
×
2044
                if err != nil {
×
2045
                        return err
×
2046
                }
×
2047

2048
                if len(chanLookup) > 0 {
×
2049
                        return ErrEdgeNotFound
×
2050
                }
×
2051

2052
                if len(channelRows) == 0 {
×
2053
                        return nil
×
2054
                }
×
2055

2056
                // Batch build all channel edges.
2057
                var chanIDsToDelete []int64
×
2058
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
2059
                        ctx, s.cfg, db, channelRows,
×
2060
                )
×
2061
                if err != nil {
×
2062
                        return err
×
2063
                }
×
2064

2065
                if markZombie {
×
2066
                        for i, row := range channelRows {
×
2067
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
2068

×
2069
                                err := handleZombieMarking(
×
2070
                                        ctx, db, v, row, edges[i],
×
2071
                                        strictZombiePruning, scid,
×
2072
                                )
×
2073
                                if err != nil {
×
2074
                                        return fmt.Errorf("unable to mark "+
×
2075
                                                "channel as zombie: %w", err)
×
2076
                                }
×
2077
                        }
2078
                }
2079

2080
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
2081
        }, func() {
×
2082
                edges = nil
×
2083

×
2084
                // Re-fill the lookup map.
×
2085
                for _, chanID := range chanIDs {
×
2086
                        chanLookup[chanID] = struct{}{}
×
2087
                }
×
2088
        })
2089
        if err != nil {
×
2090
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
2091
                        err)
×
2092
        }
×
2093

2094
        for _, chanID := range chanIDs {
×
2095
                s.rejectCache.remove(v, chanID)
×
2096
                s.chanCache.remove(v, chanID)
×
2097
        }
×
2098

2099
        return edges, nil
×
2100
}
2101

2102
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
2103
// channel identified by the channel ID. If the channel can't be found, then
2104
// ErrEdgeNotFound is returned. A struct which houses the general information
2105
// for the channel itself is returned as well as two structs that contain the
2106
// routing policies for the channel in either direction.
2107
//
2108
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
2109
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
2110
// the ChannelEdgeInfo will only include the public keys of each node.
2111
//
2112
// NOTE: part of the Store interface.
2113
func (s *SQLStore) FetchChannelEdgesByID(ctx context.Context,
2114
        v lnwire.GossipVersion, chanID uint64) (
2115
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2116
        *models.ChannelEdgePolicy, error) {
×
2117

×
2118
        var (
×
2119
                edge             *models.ChannelEdgeInfo
×
2120
                policy1, policy2 *models.ChannelEdgePolicy
×
2121
                chanIDB          = channelIDToBytes(chanID)
×
2122
        )
×
2123

×
2124
        if !isKnownGossipVersion(v) {
×
2125
                return nil, nil, nil, fmt.Errorf(
×
2126
                        "unsupported gossip version: %d", v,
×
2127
                )
×
2128
        }
×
2129

2130
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2131
                row, err := db.GetChannelBySCIDWithPolicies(
×
2132
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2133
                                Scid:    chanIDB,
×
2134
                                Version: int16(v),
×
2135
                        },
×
2136
                )
×
2137
                if errors.Is(err, sql.ErrNoRows) {
×
2138
                        // First check if this edge is perhaps in the zombie
×
2139
                        // index.
×
2140
                        zombie, err := db.GetZombieChannel(
×
2141
                                ctx, sqlc.GetZombieChannelParams{
×
2142
                                        Scid:    chanIDB,
×
2143
                                        Version: int16(v),
×
2144
                                },
×
2145
                        )
×
2146
                        if errors.Is(err, sql.ErrNoRows) {
×
2147
                                return ErrEdgeNotFound
×
2148
                        } else if err != nil {
×
2149
                                return fmt.Errorf("unable to check if "+
×
2150
                                        "channel is zombie: %w", err)
×
2151
                        }
×
2152

2153
                        // At this point, we know the channel is a zombie, so
2154
                        // we'll return an error indicating this, and we will
2155
                        // populate the edge info with the public keys of each
2156
                        // party as this is the only information we have about
2157
                        // it.
2158
                        node1, err := route.NewVertexFromBytes(zombie.NodeKey1)
×
2159
                        if err != nil {
×
2160
                                return err
×
2161
                        }
×
2162
                        node2, err := route.NewVertexFromBytes(zombie.NodeKey2)
×
2163
                        if err != nil {
×
2164
                                return err
×
2165
                        }
×
2166
                        zombieEdge, err := models.NewV1Channel(
×
2167
                                0, chainhash.Hash{}, node1, node2,
×
2168
                                &models.ChannelV1Fields{},
×
2169
                        )
×
2170
                        if err != nil {
×
2171
                                return err
×
2172
                        }
×
2173
                        edge = zombieEdge
×
2174

×
2175
                        return ErrZombieEdge
×
2176
                } else if err != nil {
×
2177
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2178
                }
×
2179

2180
                node1, node2, err := buildNodeVertices(
×
2181
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2182
                )
×
2183
                if err != nil {
×
2184
                        return err
×
2185
                }
×
2186

2187
                edge, err = getAndBuildEdgeInfo(
×
2188
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2189
                )
×
2190
                if err != nil {
×
2191
                        return fmt.Errorf("unable to build channel info: %w",
×
2192
                                err)
×
2193
                }
×
2194

2195
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2196
                if err != nil {
×
2197
                        return fmt.Errorf("unable to extract channel "+
×
2198
                                "policies: %w", err)
×
2199
                }
×
2200

2201
                policy1, policy2, err = getAndBuildChanPolicies(
×
2202
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2203
                        node1, node2,
×
2204
                )
×
2205
                if err != nil {
×
2206
                        return fmt.Errorf("unable to build channel "+
×
2207
                                "policies: %w", err)
×
2208
                }
×
2209

2210
                return nil
×
2211
        }, sqldb.NoOpReset)
2212
        if err != nil {
×
2213
                // If we are returning the ErrZombieEdge, then we also need to
×
2214
                // return the edge info as the method comment indicates that
×
2215
                // this will be populated when the edge is a zombie.
×
2216
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2217
                        err)
×
2218
        }
×
2219

2220
        return edge, policy1, policy2, nil
×
2221
}
2222

2223
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2224
// the channel identified by the funding outpoint. If the channel can't be
2225
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2226
// information for the channel itself is returned as well as two structs that
2227
// contain the routing policies for the channel in either direction.
2228
//
2229
// NOTE: part of the Store interface.
2230
func (s *SQLStore) FetchChannelEdgesByOutpoint(ctx context.Context,
2231
        v lnwire.GossipVersion, op *wire.OutPoint) (
2232
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2233
        *models.ChannelEdgePolicy, error) {
×
2234

×
2235
        var (
×
2236
                edge             *models.ChannelEdgeInfo
×
2237
                policy1, policy2 *models.ChannelEdgePolicy
×
2238
        )
×
2239

×
2240
        if !isKnownGossipVersion(v) {
×
2241
                return nil, nil, nil, fmt.Errorf(
×
2242
                        "unsupported gossip version: %d", v,
×
2243
                )
×
2244
        }
×
2245

2246
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2247
                row, err := db.GetChannelByOutpointWithPolicies(
×
2248
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2249
                                Outpoint: op.String(),
×
2250
                                Version:  int16(v),
×
2251
                        },
×
2252
                )
×
2253
                if errors.Is(err, sql.ErrNoRows) {
×
2254
                        return ErrEdgeNotFound
×
2255
                } else if err != nil {
×
2256
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2257
                }
×
2258

2259
                node1, node2, err := buildNodeVertices(
×
2260
                        row.Node1Pubkey, row.Node2Pubkey,
×
2261
                )
×
2262
                if err != nil {
×
2263
                        return err
×
2264
                }
×
2265

2266
                edge, err = getAndBuildEdgeInfo(
×
2267
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2268
                )
×
2269
                if err != nil {
×
2270
                        return fmt.Errorf("unable to build channel info: %w",
×
2271
                                err)
×
2272
                }
×
2273

2274
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2275
                if err != nil {
×
2276
                        return fmt.Errorf("unable to extract channel "+
×
2277
                                "policies: %w", err)
×
2278
                }
×
2279

2280
                policy1, policy2, err = getAndBuildChanPolicies(
×
2281
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2282
                        node1, node2,
×
2283
                )
×
2284
                if err != nil {
×
2285
                        return fmt.Errorf("unable to build channel "+
×
2286
                                "policies: %w", err)
×
2287
                }
×
2288

2289
                return nil
×
2290
        }, sqldb.NoOpReset)
2291
        if err != nil {
×
2292
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2293
                        err)
×
2294
        }
×
2295

2296
        return edge, policy1, policy2, nil
×
2297
}
2298

2299
// HasV1ChannelEdge returns true if the database knows of a channel edge
2300
// with the passed channel ID, and false otherwise. If an edge with that ID
2301
// is found within the graph, then two time stamps representing the last time
2302
// the edge was updated for both directed edges are returned along with the
2303
// boolean. If it is not found, then the zombie index is checked and its
2304
// result is returned as the second boolean.
2305
//
2306
// NOTE: part of the Store interface.
2307
func (s *SQLStore) HasV1ChannelEdge(ctx context.Context,
2308
        chanID uint64) (time.Time, time.Time, bool, bool, error) {
×
2309

×
2310
        var (
×
2311
                exists          bool
×
2312
                isZombie        bool
×
2313
                node1LastUpdate time.Time
×
2314
                node2LastUpdate time.Time
×
2315
        )
×
2316

×
2317
        // We'll query the cache with the shared lock held to allow multiple
×
2318
        // readers to access values in the cache concurrently if they exist.
×
2319
        s.cacheMu.RLock()
×
2320
        if entry, ok := s.rejectCache.get(gossipV1, chanID); ok {
×
2321
                s.cacheMu.RUnlock()
×
2322
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2323
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2324
                exists, isZombie = entry.flags.unpack()
×
2325

×
2326
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2327
        }
×
2328
        s.cacheMu.RUnlock()
×
2329

×
2330
        s.cacheMu.Lock()
×
2331
        defer s.cacheMu.Unlock()
×
2332

×
2333
        // The item was not found with the shared lock, so we'll acquire the
×
2334
        // exclusive lock and check the cache again in case another method added
×
2335
        // the entry to the cache while no lock was held.
×
2336
        if entry, ok := s.rejectCache.get(gossipV1, chanID); ok {
×
2337
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2338
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2339
                exists, isZombie = entry.flags.unpack()
×
2340

×
2341
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2342
        }
×
2343

2344
        chanIDB := channelIDToBytes(chanID)
×
2345
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2346
                channel, err := db.GetChannelBySCID(
×
2347
                        ctx, sqlc.GetChannelBySCIDParams{
×
2348
                                Scid:    chanIDB,
×
2349
                                Version: int16(gossipV1),
×
2350
                        },
×
2351
                )
×
2352
                if errors.Is(err, sql.ErrNoRows) {
×
2353
                        // Check if it is a zombie channel.
×
2354
                        isZombie, err = db.IsZombieChannel(
×
2355
                                ctx, sqlc.IsZombieChannelParams{
×
2356
                                        Scid:    chanIDB,
×
2357
                                        Version: int16(gossipV1),
×
2358
                                },
×
2359
                        )
×
2360
                        if err != nil {
×
2361
                                return fmt.Errorf("could not check if channel "+
×
2362
                                        "is zombie: %w", err)
×
2363
                        }
×
2364

2365
                        return nil
×
2366
                } else if err != nil {
×
2367
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2368
                }
×
2369

2370
                exists = true
×
2371

×
2372
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2373
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2374
                                Version:   int16(gossipV1),
×
2375
                                ChannelID: channel.ID,
×
2376
                                NodeID:    channel.NodeID1,
×
2377
                        },
×
2378
                )
×
2379
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2380
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2381
                                err)
×
2382
                } else if err == nil {
×
2383
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2384
                }
×
2385

2386
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2387
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2388
                                Version:   int16(gossipV1),
×
2389
                                ChannelID: channel.ID,
×
2390
                                NodeID:    channel.NodeID2,
×
2391
                        },
×
2392
                )
×
2393
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2394
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2395
                                err)
×
2396
                } else if err == nil {
×
2397
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2398
                }
×
2399

2400
                return nil
×
2401
        }, sqldb.NoOpReset)
2402
        if err != nil {
×
2403
                return time.Time{}, time.Time{}, false, false,
×
2404
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2405
        }
×
2406

2407
        s.rejectCache.insert(
×
2408
                gossipV1, chanID,
×
2409
                newRejectCacheEntryV1(
×
2410
                        node1LastUpdate, node2LastUpdate, exists,
×
2411
                        isZombie,
×
2412
                ),
×
2413
        )
×
2414

×
2415
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2416
}
2417

2418
// HasChannelEdge returns true if the database knows of a channel edge with the
2419
// passed channel ID and gossip version, and false otherwise. If an edge with
2420
// that ID is found within the graph, then the zombie index is checked and its
2421
// result is returned as the second boolean.
2422
//
2423
// NOTE: part of the Store interface.
2424
func (s *SQLStore) HasChannelEdge(ctx context.Context,
2425
        v lnwire.GossipVersion, chanID uint64) (bool, bool, error) {
×
2426

×
2427
        if !isKnownGossipVersion(v) {
×
2428
                return false, false, fmt.Errorf(
×
2429
                        "unsupported gossip version: %d", v,
×
2430
                )
×
2431
        }
×
2432

2433
        var (
×
2434
                exists          bool
×
2435
                isZombie        bool
×
2436
                node1LastUpdate time.Time
×
2437
                node2LastUpdate time.Time
×
2438
                node1Block      uint32
×
2439
                node2Block      uint32
×
2440
        )
×
2441

×
2442
        // We'll query the cache with the shared lock held to allow multiple
×
2443
        // readers to access values in the cache concurrently if they exist.
×
2444
        s.cacheMu.RLock()
×
2445
        if entry, ok := s.rejectCache.get(v, chanID); ok {
×
2446
                s.cacheMu.RUnlock()
×
2447
                exists, isZombie = entry.flags.unpack()
×
2448
                return exists, isZombie, nil
×
2449
        }
×
2450
        s.cacheMu.RUnlock()
×
2451

×
2452
        s.cacheMu.Lock()
×
2453
        defer s.cacheMu.Unlock()
×
2454

×
2455
        // The item was not found with the shared lock, so we'll acquire the
×
2456
        // exclusive lock and check the cache again in case another method added
×
2457
        // the entry to the cache while no lock was held.
×
2458
        if entry, ok := s.rejectCache.get(v, chanID); ok {
×
2459
                exists, isZombie = entry.flags.unpack()
×
2460
                return exists, isZombie, nil
×
2461
        }
×
2462

2463
        chanIDB := channelIDToBytes(chanID)
×
2464
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2465
                channel, err := db.GetChannelBySCID(
×
2466
                        ctx, sqlc.GetChannelBySCIDParams{
×
2467
                                Scid:    chanIDB,
×
2468
                                Version: int16(v),
×
2469
                        },
×
2470
                )
×
2471
                if errors.Is(err, sql.ErrNoRows) {
×
2472
                        // Check if it is a zombie channel.
×
2473
                        isZombie, err = db.IsZombieChannel(
×
2474
                                ctx, sqlc.IsZombieChannelParams{
×
2475
                                        Scid:    chanIDB,
×
2476
                                        Version: int16(v),
×
2477
                                },
×
2478
                        )
×
2479
                        if err != nil {
×
2480
                                return fmt.Errorf("could not check if channel "+
×
2481
                                        "is zombie: %w", err)
×
2482
                        }
×
2483

2484
                        return nil
×
2485
                } else if err != nil {
×
2486
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2487
                }
×
2488

2489
                exists = true
×
2490

×
2491
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2492
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2493
                                Version:   int16(v),
×
2494
                                ChannelID: channel.ID,
×
2495
                                NodeID:    channel.NodeID1,
×
2496
                        },
×
2497
                )
×
2498
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2499
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2500
                                err)
×
2501
                } else if err == nil {
×
2502
                        switch v {
×
2503
                        case gossipV1:
×
2504
                                if policy1.LastUpdate.Valid {
×
2505
                                        node1LastUpdate = time.Unix(
×
2506
                                                policy1.LastUpdate.Int64, 0,
×
2507
                                        )
×
2508
                                }
×
2509
                        case gossipV2:
×
2510
                                if policy1.BlockHeight.Valid {
×
2511
                                        node1Block = uint32(
×
2512
                                                policy1.BlockHeight.Int64,
×
2513
                                        )
×
2514
                                }
×
2515
                        }
2516
                }
2517

2518
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2519
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2520
                                Version:   int16(v),
×
2521
                                ChannelID: channel.ID,
×
2522
                                NodeID:    channel.NodeID2,
×
2523
                        },
×
2524
                )
×
2525
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2526
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2527
                                err)
×
2528
                } else if err == nil {
×
2529
                        switch v {
×
2530
                        case gossipV1:
×
2531
                                if policy2.LastUpdate.Valid {
×
2532
                                        node2LastUpdate = time.Unix(
×
2533
                                                policy2.LastUpdate.Int64, 0,
×
2534
                                        )
×
2535
                                }
×
2536
                        case gossipV2:
×
2537
                                if policy2.BlockHeight.Valid {
×
2538
                                        node2Block = uint32(
×
2539
                                                policy2.BlockHeight.Int64,
×
2540
                                        )
×
2541
                                }
×
2542
                        }
2543
                }
2544

2545
                return nil
×
2546
        }, sqldb.NoOpReset)
2547
        if err != nil {
×
2548
                return false, false,
×
2549
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2550
        }
×
2551

2552
        var entry rejectCacheEntry
×
2553
        switch v {
×
2554
        case gossipV1:
×
2555
                entry = newRejectCacheEntryV1(
×
2556
                        node1LastUpdate, node2LastUpdate, exists, isZombie,
×
2557
                )
×
2558
        case gossipV2:
×
2559
                entry = newRejectCacheEntryV2(
×
2560
                        node1Block, node2Block, exists, isZombie,
×
2561
                )
×
2562
        }
2563
        s.rejectCache.insert(v, chanID, entry)
×
2564

×
2565
        return exists, isZombie, nil
×
2566
}
2567

2568
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2569
// passed channel point (outpoint). If the passed channel doesn't exist within
2570
// the database, then ErrEdgeNotFound is returned.
2571
//
2572
// NOTE: part of the Store interface.
2573
func (s *SQLStore) ChannelID(ctx context.Context, v lnwire.GossipVersion,
2574
        chanPoint *wire.OutPoint) (uint64, error) {
×
2575

×
2576
        var (
×
2577
                channelID uint64
×
2578
        )
×
2579
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2580
                chanID, err := db.GetSCIDByOutpoint(
×
2581
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2582
                                Outpoint: chanPoint.String(),
×
2583
                                Version:  int16(v),
×
2584
                        },
×
2585
                )
×
2586
                if errors.Is(err, sql.ErrNoRows) {
×
2587
                        return ErrEdgeNotFound
×
2588
                } else if err != nil {
×
2589
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2590
                                err)
×
2591
                }
×
2592

2593
                channelID = byteOrder.Uint64(chanID)
×
2594

×
2595
                return nil
×
2596
        }, sqldb.NoOpReset)
2597
        if err != nil {
×
2598
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2599
        }
×
2600

2601
        return channelID, nil
×
2602
}
2603

2604
// IsPublicNode is a helper method that determines whether the node with the
2605
// given public key is seen as a public node in the graph from the graph's
2606
// source node's point of view.
2607
//
2608
// NOTE: part of the Store interface.
2609
func (s *SQLStore) IsPublicNode(ctx context.Context, v lnwire.GossipVersion,
2610
        pubKey [33]byte) (bool, error) {
×
2611

×
2612
        if !isKnownGossipVersion(v) {
×
2613
                return false, fmt.Errorf("unsupported gossip version: %d", v)
×
2614
        }
×
2615

2616
        var isPublic bool
×
2617
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2618
                var err error
×
2619
                switch v {
×
2620
                case gossipV1:
×
2621
                        isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2622
                case gossipV2:
×
2623
                        isPublic, err = db.IsPublicV2Node(ctx, pubKey[:])
×
2624
                }
2625

2626
                return err
×
2627
        }, sqldb.NoOpReset)
2628
        if err != nil {
×
2629
                return false, fmt.Errorf("unable to check if node is "+
×
2630
                        "public: %w", err)
×
2631
        }
×
2632

2633
        return isPublic, nil
×
2634
}
2635

2636
// FetchChanInfos returns the set of channel edges that correspond to the passed
2637
// channel ID's. If an edge is the query is unknown to the database, it will
2638
// skipped and the result will contain only those edges that exist at the time
2639
// of the query. This can be used to respond to peer queries that are seeking to
2640
// fill in gaps in their view of the channel graph.
2641
//
2642
// NOTE: part of the Store interface.
2643
func (s *SQLStore) FetchChanInfos(ctx context.Context,
2644
        v lnwire.GossipVersion, chanIDs []uint64) ([]ChannelEdge, error) {
×
2645

×
2646
        var (
×
2647
                edges = make(map[uint64]ChannelEdge)
×
2648
        )
×
2649
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2650
                if !isKnownGossipVersion(v) {
×
2651
                        return fmt.Errorf("unsupported gossip version: %d", v)
×
2652
                }
×
2653

2654
                // First, collect all channel rows.
2655
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2656
                chanCallBack := func(ctx context.Context,
×
2657
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2658

×
2659
                        channelRows = append(channelRows, row)
×
2660
                        return nil
×
2661
                }
×
2662

2663
                err := s.forEachChanWithPoliciesInSCIDList(
×
2664
                        ctx, db, v, chanCallBack, chanIDs,
×
2665
                )
×
2666
                if err != nil {
×
2667
                        return err
×
2668
                }
×
2669

2670
                if len(channelRows) == 0 {
×
2671
                        return nil
×
2672
                }
×
2673

2674
                // Batch build all channel edges.
2675
                chans, err := batchBuildChannelEdges(
×
2676
                        ctx, s.cfg, db, channelRows,
×
2677
                )
×
2678
                if err != nil {
×
2679
                        return fmt.Errorf("unable to build channel edges: %w",
×
2680
                                err)
×
2681
                }
×
2682

2683
                for _, c := range chans {
×
2684
                        edges[c.Info.ChannelID] = c
×
2685
                }
×
2686

2687
                return err
×
2688
        }, func() {
×
2689
                clear(edges)
×
2690
        })
×
2691
        if err != nil {
×
2692
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2693
        }
×
2694

2695
        res := make([]ChannelEdge, 0, len(edges))
×
2696
        for _, chanID := range chanIDs {
×
2697
                edge, ok := edges[chanID]
×
2698
                if !ok {
×
2699
                        continue
×
2700
                }
2701

2702
                res = append(res, edge)
×
2703
        }
2704

2705
        return res, nil
×
2706
}
2707

2708
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2709
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2710
// channels in a paginated manner.
2711
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2712
        db SQLQueries, v lnwire.GossipVersion, cb func(ctx context.Context,
2713
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2714
        chanIDs []uint64) error {
×
2715

×
2716
        queryWrapper := func(ctx context.Context,
×
2717
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2718
                error) {
×
2719

×
2720
                return db.GetChannelsBySCIDWithPolicies(
×
2721
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2722
                                Version: int16(v),
×
2723
                                Scids:   scids,
×
2724
                        },
×
2725
                )
×
2726
        }
×
2727

2728
        return sqldb.ExecuteBatchQuery(
×
2729
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2730
                cb,
×
2731
        )
×
2732
}
2733

2734
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2735
// ID's that we don't know and are not known zombies of the passed set. In other
2736
// words, we perform a set difference of our set of chan ID's and the ones
2737
// passed in. This method can be used by callers to determine the set of
2738
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2739
// known zombies is also returned.
2740
//
2741
// NOTE: part of the Store interface.
2742
func (s *SQLStore) FilterKnownChanIDs(ctx context.Context,
2743
        chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, error) {
×
2744

×
2745
        var (
×
2746
                newChanIDs   []uint64
×
2747
                knownZombies []ChannelUpdateInfo
×
2748
                infoLookup   = make(
×
2749
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2750
                )
×
2751
        )
×
2752

×
2753
        // We first build a lookup map of the channel ID's to the
×
2754
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2755
        // already know about.
×
2756
        for _, chanInfo := range chansInfo {
×
2757
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2758
        }
×
2759

2760
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2761
                // The call-back function deletes known channels from
×
2762
                // infoLookup, so that we can later check which channels are
×
2763
                // zombies by only looking at the remaining channels in the set.
×
2764
                cb := func(ctx context.Context,
×
2765
                        channel sqlc.GraphChannel) error {
×
2766

×
2767
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2768

×
2769
                        return nil
×
2770
                }
×
2771

2772
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2773
                if err != nil {
×
2774
                        return fmt.Errorf("unable to iterate through "+
×
2775
                                "channels: %w", err)
×
2776
                }
×
2777

2778
                // We want to ensure that we deal with the channels in the
2779
                // same order that they were passed in, so we iterate over the
2780
                // original chansInfo slice and then check if that channel is
2781
                // still in the infoLookup map.
2782
                for _, chanInfo := range chansInfo {
×
2783
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2784
                        if _, ok := infoLookup[channelID]; !ok {
×
2785
                                continue
×
2786
                        }
2787

2788
                        isZombie, err := db.IsZombieChannel(
×
2789
                                ctx, sqlc.IsZombieChannelParams{
×
2790
                                        Scid:    channelIDToBytes(channelID),
×
2791
                                        Version: int16(lnwire.GossipVersion1),
×
2792
                                },
×
2793
                        )
×
2794
                        if err != nil {
×
2795
                                return fmt.Errorf("unable to fetch zombie "+
×
2796
                                        "channel: %w", err)
×
2797
                        }
×
2798

2799
                        if isZombie {
×
2800
                                knownZombies = append(knownZombies, chanInfo)
×
2801

×
2802
                                continue
×
2803
                        }
2804

2805
                        newChanIDs = append(newChanIDs, channelID)
×
2806
                }
2807

2808
                return nil
×
2809
        }, func() {
×
2810
                newChanIDs = nil
×
2811
                knownZombies = nil
×
2812
                // Rebuild the infoLookup map in case of a rollback.
×
2813
                for _, chanInfo := range chansInfo {
×
2814
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2815
                        infoLookup[scid] = chanInfo
×
2816
                }
×
2817
        })
2818
        if err != nil {
×
2819
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2820
        }
×
2821

2822
        return newChanIDs, knownZombies, nil
×
2823
}
2824

2825
// forEachChanInSCIDList is a helper method that executes a paged query
2826
// against the database to fetch all channels that match the passed
2827
// ChannelUpdateInfo slice. The callback function is called for each channel
2828
// that is found.
2829
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2830
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2831
        chansInfo []ChannelUpdateInfo) error {
×
2832

×
2833
        queryWrapper := func(ctx context.Context,
×
2834
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2835

×
2836
                return db.GetChannelsBySCIDs(
×
2837
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2838
                                Version: int16(lnwire.GossipVersion1),
×
2839
                                Scids:   scids,
×
2840
                        },
×
2841
                )
×
2842
        }
×
2843

2844
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2845
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2846

×
2847
                return channelIDToBytes(channelID)
×
2848
        }
×
2849

2850
        return sqldb.ExecuteBatchQuery(
×
2851
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2852
                cb,
×
2853
        )
×
2854
}
2855

2856
// PruneGraphNodes is a garbage collection method which attempts to prune out
2857
// any nodes from the channel graph that are currently unconnected. This ensure
2858
// that we only maintain a graph of reachable nodes. In the event that a pruned
2859
// node gains more channels, it will be re-added back to the graph.
2860
//
2861
// NOTE: this prunes nodes across protocol versions. It will never prune the
2862
// source nodes.
2863
//
2864
// NOTE: part of the Store interface.
2865
func (s *SQLStore) PruneGraphNodes(ctx context.Context) (
2866
        []route.Vertex, error) {
×
2867

×
2868
        var prunedNodes []route.Vertex
×
2869
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2870
                var err error
×
2871
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2872

×
2873
                return err
×
2874
        }, func() {
×
2875
                prunedNodes = nil
×
2876
        })
×
2877
        if err != nil {
×
2878
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2879
        }
×
2880

2881
        return prunedNodes, nil
×
2882
}
2883

2884
// PruneGraph prunes newly closed channels from the channel graph in response
2885
// to a new block being solved on the network. Any transactions which spend the
2886
// funding output of any known channels within he graph will be deleted.
2887
// Additionally, the "prune tip", or the last block which has been used to
2888
// prune the graph is stored so callers can ensure the graph is fully in sync
2889
// with the current UTXO state. A slice of channels that have been closed by
2890
// the target block along with any pruned nodes are returned if the function
2891
// succeeds without error.
2892
//
2893
// NOTE: part of the Store interface.
2894
func (s *SQLStore) PruneGraph(ctx context.Context,
2895
        spentOutputs []*wire.OutPoint, blockHash *chainhash.Hash,
2896
        blockHeight uint32) ([]*models.ChannelEdgeInfo, []route.Vertex,
2897
        error) {
×
2898

×
2899
        s.cacheMu.Lock()
×
2900
        defer s.cacheMu.Unlock()
×
2901

×
2902
        var (
×
2903
                closedChans []*models.ChannelEdgeInfo
×
2904
                prunedNodes []route.Vertex
×
2905
        )
×
2906
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2907
                // First, collect all channel rows that need to be pruned.
×
2908
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2909
                channelCallback := func(ctx context.Context,
×
2910
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2911

×
2912
                        channelRows = append(channelRows, row)
×
2913

×
2914
                        return nil
×
2915
                }
×
2916

2917
                err := s.forEachChanInOutpoints(
×
2918
                        ctx, db, spentOutputs, channelCallback,
×
2919
                )
×
2920
                if err != nil {
×
2921
                        return fmt.Errorf("unable to fetch channels by "+
×
2922
                                "outpoints: %w", err)
×
2923
                }
×
2924

2925
                if len(channelRows) == 0 {
×
2926
                        // There are no channels to prune. So we can exit early
×
2927
                        // after updating the prune log.
×
2928
                        err = db.UpsertPruneLogEntry(
×
2929
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2930
                                        BlockHash:   blockHash[:],
×
2931
                                        BlockHeight: int64(blockHeight),
×
2932
                                },
×
2933
                        )
×
2934
                        if err != nil {
×
2935
                                return fmt.Errorf("unable to insert prune log "+
×
2936
                                        "entry: %w", err)
×
2937
                        }
×
2938

2939
                        return nil
×
2940
                }
2941

2942
                // Batch build all channel edges for pruning.
2943
                var chansToDelete []int64
×
2944
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2945
                        ctx, s.cfg, db, channelRows,
×
2946
                )
×
2947
                if err != nil {
×
2948
                        return err
×
2949
                }
×
2950

2951
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2952
                if err != nil {
×
2953
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2954
                }
×
2955

2956
                err = db.UpsertPruneLogEntry(
×
2957
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2958
                                BlockHash:   blockHash[:],
×
2959
                                BlockHeight: int64(blockHeight),
×
2960
                        },
×
2961
                )
×
2962
                if err != nil {
×
2963
                        return fmt.Errorf("unable to insert prune log "+
×
2964
                                "entry: %w", err)
×
2965
                }
×
2966

2967
                // Now that we've pruned some channels, we'll also prune any
2968
                // nodes that no longer have any channels.
2969
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2970
                if err != nil {
×
2971
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2972
                                err)
×
2973
                }
×
2974

2975
                return nil
×
2976
        }, func() {
×
2977
                prunedNodes = nil
×
2978
                closedChans = nil
×
2979
        })
×
2980
        if err != nil {
×
2981
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2982
        }
×
2983

2984
        for _, channel := range closedChans {
×
2985
                s.rejectCache.remove(channel.Version, channel.ChannelID)
×
2986
                s.chanCache.remove(channel.Version, channel.ChannelID)
×
2987
        }
×
2988

2989
        return closedChans, prunedNodes, nil
×
2990
}
2991

2992
// forEachChanInOutpoints is a helper function that executes a paginated
2993
// query to fetch channels by their outpoints and applies the given call-back
2994
// to each.
2995
//
2996
// NOTE: this fetches channels for all protocol versions.
2997
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2998
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2999
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
3000

×
3001
        // Create a wrapper that uses the transaction's db instance to execute
×
3002
        // the query.
×
3003
        queryWrapper := func(ctx context.Context,
×
3004
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
3005
                error) {
×
3006

×
3007
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
3008
        }
×
3009

3010
        // Define the conversion function from Outpoint to string.
3011
        outpointToString := func(outpoint *wire.OutPoint) string {
×
3012
                return outpoint.String()
×
3013
        }
×
3014

3015
        return sqldb.ExecuteBatchQuery(
×
3016
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
3017
                queryWrapper, cb,
×
3018
        )
×
3019
}
3020

3021
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
3022
        dbIDs []int64) error {
×
3023

×
3024
        // Create a wrapper that uses the transaction's db instance to execute
×
3025
        // the query.
×
3026
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
3027
                return nil, db.DeleteChannels(ctx, ids)
×
3028
        }
×
3029

3030
        idConverter := func(id int64) int64 {
×
3031
                return id
×
3032
        }
×
3033

3034
        return sqldb.ExecuteBatchQuery(
×
3035
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
3036
                queryWrapper, func(ctx context.Context, _ any) error {
×
3037
                        return nil
×
3038
                },
×
3039
        )
3040
}
3041

3042
// ChannelView returns the verifiable edge information for each active channel
3043
// within the known channel graph. The set of UTXOs (along with their scripts)
3044
// returned are the ones that need to be watched on chain to detect channel
3045
// closes on the resident blockchain.
3046
//
3047
// NOTE: part of the Store interface.
3048
func (s *SQLStore) ChannelView(ctx context.Context,
NEW
3049
        v lnwire.GossipVersion) ([]EdgePoint, error) {
×
NEW
3050

×
3051
        var edgePoints []EdgePoint
×
3052

×
3053
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
3054
                switch v {
×
NEW
3055
                case gossipV1:
×
NEW
3056
                        handleChannel := func(_ context.Context,
×
NEW
3057
                                channel sqlc.ListChannelsPaginatedRow) error {
×
3058

×
NEW
3059
                                key1, err := route.NewVertexFromBytes(
×
NEW
3060
                                        channel.BitcoinKey1,
×
NEW
3061
                                )
×
NEW
3062
                                if err != nil {
×
NEW
3063
                                        return err
×
NEW
3064
                                }
×
3065

NEW
3066
                                key2, err := route.NewVertexFromBytes(
×
NEW
3067
                                        channel.BitcoinKey2,
×
NEW
3068
                                )
×
NEW
3069
                                if err != nil {
×
NEW
3070
                                        return err
×
NEW
3071
                                }
×
3072

NEW
3073
                                edge := &models.ChannelEdgeInfo{
×
NEW
3074
                                        Version:          gossipV1,
×
NEW
3075
                                        BitcoinKey1Bytes: fn.Some(key1),
×
NEW
3076
                                        BitcoinKey2Bytes: fn.Some(key2),
×
NEW
3077
                                }
×
NEW
3078
                                pkScript, err := edge.FundingPKScript()
×
NEW
3079
                                if err != nil {
×
NEW
3080
                                        return err
×
NEW
3081
                                }
×
3082

NEW
3083
                                op, err := wire.NewOutPointFromString(
×
NEW
3084
                                        channel.Outpoint,
×
NEW
3085
                                )
×
NEW
3086
                                if err != nil {
×
NEW
3087
                                        return err
×
NEW
3088
                                }
×
3089

NEW
3090
                                edgePoints = append(edgePoints, EdgePoint{
×
NEW
3091
                                        FundingPkScript: pkScript,
×
NEW
3092
                                        OutPoint:        *op,
×
NEW
3093
                                })
×
NEW
3094

×
NEW
3095
                                return nil
×
3096
                        }
3097

NEW
3098
                        queryFunc := func(ctx context.Context, lastID int64,
×
NEW
3099
                                limit int32) ([]sqlc.ListChannelsPaginatedRow,
×
NEW
3100
                                error) {
×
NEW
3101

×
NEW
3102
                                return db.ListChannelsPaginated(
×
NEW
3103
                                        ctx, sqlc.ListChannelsPaginatedParams{
×
NEW
3104
                                                Version: int16(gossipV1),
×
NEW
3105
                                                ID:      lastID,
×
NEW
3106
                                                Limit:   limit,
×
NEW
3107
                                        },
×
NEW
3108
                                )
×
UNCOV
3109
                        }
×
3110

NEW
3111
                        extractCursor := func(
×
NEW
3112
                                row sqlc.ListChannelsPaginatedRow) int64 {
×
3113

×
NEW
3114
                                return row.ID
×
NEW
3115
                        }
×
3116

NEW
3117
                        return sqldb.ExecutePaginatedQuery(
×
NEW
3118
                                ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
3119
                                extractCursor, handleChannel,
×
NEW
3120
                        )
×
3121

NEW
3122
                case gossipV2:
×
NEW
3123
                        handleChannel := func(_ context.Context,
×
NEW
3124
                                channel sqlc.ListChannelsPaginatedV2Row) error {
×
NEW
3125

×
NEW
3126
                                op, err := wire.NewOutPointFromString(
×
NEW
3127
                                        channel.Outpoint,
×
NEW
3128
                                )
×
NEW
3129
                                if err != nil {
×
NEW
3130
                                        return err
×
NEW
3131
                                }
×
3132

NEW
3133
                                pkScript := channel.FundingPkScript
×
NEW
3134
                                edgePoints = append(edgePoints, EdgePoint{
×
NEW
3135
                                        FundingPkScript: pkScript,
×
NEW
3136
                                        OutPoint:        *op,
×
NEW
3137
                                })
×
NEW
3138

×
NEW
3139
                                return nil
×
3140
                        }
3141

NEW
3142
                        queryFunc := func(ctx context.Context, lastID int64,
×
NEW
3143
                                limit int32) ([]sqlc.ListChannelsPaginatedV2Row,
×
NEW
3144
                                error) {
×
NEW
3145

×
NEW
3146
                                return db.ListChannelsPaginatedV2(
×
NEW
3147
                                        ctx, sqlc.ListChannelsPaginatedV2Params{
×
NEW
3148
                                                ID:    lastID,
×
NEW
3149
                                                Limit: limit,
×
NEW
3150
                                        },
×
NEW
3151
                                )
×
NEW
3152
                        }
×
3153

NEW
3154
                        extractCursor := func(
×
NEW
3155
                                row sqlc.ListChannelsPaginatedV2Row) int64 {
×
NEW
3156

×
NEW
3157
                                return row.ID
×
NEW
3158
                        }
×
3159

NEW
3160
                        return sqldb.ExecutePaginatedQuery(
×
NEW
3161
                                ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
3162
                                extractCursor, handleChannel,
×
3163
                        )
×
3164

NEW
3165
                default:
×
NEW
3166
                        return fmt.Errorf("unsupported gossip version: %d", v)
×
3167
                }
UNCOV
3168
        }, func() {
×
3169
                edgePoints = nil
×
3170
        })
×
3171
        if err != nil {
×
3172
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
3173
        }
×
3174

3175
        return edgePoints, nil
×
3176
}
3177

3178
// PruneTip returns the block height and hash of the latest block that has been
3179
// used to prune channels in the graph. Knowing the "prune tip" allows callers
3180
// to tell if the graph is currently in sync with the current best known UTXO
3181
// state.
3182
//
3183
// NOTE: part of the Store interface.
3184
func (s *SQLStore) PruneTip(ctx context.Context) (*chainhash.Hash, uint32,
3185
        error) {
×
3186

×
3187
        var (
×
3188
                tipHash   chainhash.Hash
×
3189
                tipHeight uint32
×
3190
        )
×
3191
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3192
                pruneTip, err := db.GetPruneTip(ctx)
×
3193
                if errors.Is(err, sql.ErrNoRows) {
×
3194
                        return ErrGraphNeverPruned
×
3195
                } else if err != nil {
×
3196
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
3197
                }
×
3198

3199
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
3200
                tipHeight = uint32(pruneTip.BlockHeight)
×
3201

×
3202
                return nil
×
3203
        }, sqldb.NoOpReset)
3204
        if err != nil {
×
3205
                return nil, 0, err
×
3206
        }
×
3207

3208
        return &tipHash, tipHeight, nil
×
3209
}
3210

3211
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
3212
//
3213
// NOTE: this prunes nodes across protocol versions. It will never prune the
3214
// source nodes.
3215
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
3216
        db SQLQueries) ([]route.Vertex, error) {
×
3217

×
3218
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
3219
        if err != nil {
×
3220
                return nil, fmt.Errorf("unable to delete unconnected "+
×
3221
                        "nodes: %w", err)
×
3222
        }
×
3223

3224
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
3225
        for i, nodeKey := range nodeKeys {
×
3226
                pub, err := route.NewVertexFromBytes(nodeKey)
×
3227
                if err != nil {
×
3228
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
3229
                                "from bytes: %w", err)
×
3230
                }
×
3231

3232
                prunedNodes[i] = pub
×
3233
        }
3234

3235
        return prunedNodes, nil
×
3236
}
3237

3238
// DisconnectBlockAtHeight is used to indicate that the block specified
3239
// by the passed height has been disconnected from the main chain. This
3240
// will "rewind" the graph back to the height below, deleting channels
3241
// that are no longer confirmed from the graph. The prune log will be
3242
// set to the last prune height valid for the remaining chain.
3243
// Channels that were removed from the graph resulting from the
3244
// disconnected block are returned.
3245
//
3246
// NOTE: part of the Store interface.
3247
func (s *SQLStore) DisconnectBlockAtHeight(ctx context.Context,
3248
        height uint32) ([]*models.ChannelEdgeInfo, error) {
×
3249

×
3250
        var (
×
3251
                // Every channel having a ShortChannelID starting at 'height'
×
3252
                // will no longer be confirmed.
×
3253
                startShortChanID = lnwire.ShortChannelID{
×
3254
                        BlockHeight: height,
×
3255
                }
×
3256

×
3257
                // Delete everything after this height from the db up until the
×
3258
                // SCID alias range.
×
3259
                endShortChanID = aliasmgr.StartingAlias
×
3260

×
3261
                removedChans []*models.ChannelEdgeInfo
×
3262

×
3263
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
3264
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
3265
        )
×
3266

×
3267
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3268
                rows, err := db.GetChannelsBySCIDRange(
×
3269
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
3270
                                StartScid: chanIDStart,
×
3271
                                EndScid:   chanIDEnd,
×
3272
                        },
×
3273
                )
×
3274
                if err != nil {
×
3275
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
3276
                }
×
3277

3278
                if len(rows) == 0 {
×
3279
                        // No channels to disconnect, but still clean up prune
×
3280
                        // log.
×
3281
                        return db.DeletePruneLogEntriesInRange(
×
3282
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
3283
                                        StartHeight: int64(height),
×
3284
                                        EndHeight: int64(
×
3285
                                                endShortChanID.BlockHeight,
×
3286
                                        ),
×
3287
                                },
×
3288
                        )
×
3289
                }
×
3290

3291
                // Batch build all channel edges for disconnection.
3292
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
3293
                        ctx, s.cfg, db, rows,
×
3294
                )
×
3295
                if err != nil {
×
3296
                        return err
×
3297
                }
×
3298

3299
                removedChans = channelEdges
×
3300

×
3301
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
3302
                if err != nil {
×
3303
                        return fmt.Errorf("unable to delete channels: %w", err)
×
3304
                }
×
3305

3306
                return db.DeletePruneLogEntriesInRange(
×
3307
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
3308
                                StartHeight: int64(height),
×
3309
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
3310
                        },
×
3311
                )
×
3312
        }, func() {
×
3313
                removedChans = nil
×
3314
        })
×
3315
        if err != nil {
×
3316
                return nil, fmt.Errorf("unable to disconnect block at "+
×
3317
                        "height: %w", err)
×
3318
        }
×
3319

3320
        s.cacheMu.Lock()
×
3321
        for _, channel := range removedChans {
×
3322
                s.rejectCache.remove(channel.Version, channel.ChannelID)
×
3323
                s.chanCache.remove(channel.Version, channel.ChannelID)
×
3324
        }
×
3325
        s.cacheMu.Unlock()
×
3326

×
3327
        return removedChans, nil
×
3328
}
3329

3330
// AddEdgeProof sets the proof of an existing edge in the graph database.
3331
//
3332
// NOTE: part of the Store interface.
3333
func (s *SQLStore) AddEdgeProof(ctx context.Context,
3334
        scid lnwire.ShortChannelID, proof *models.ChannelAuthProof) error {
×
3335

×
3336
        if !isKnownGossipVersion(proof.Version) {
×
3337
                return fmt.Errorf("unsupported gossip version: %d",
×
3338
                        proof.Version)
×
3339
        }
×
3340

3341
        var (
×
3342
                scidBytes = channelIDToBytes(scid.ToUint64())
×
3343
        )
×
3344

×
3345
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3346
                var (
×
3347
                        res sql.Result
×
3348
                        err error
×
3349
                )
×
3350
                switch proof.Version {
×
3351
                case gossipV1:
×
3352
                        res, err = db.AddV1ChannelProof(
×
3353
                                ctx, sqlc.AddV1ChannelProofParams{
×
3354
                                        Scid:              scidBytes,
×
3355
                                        Node1Signature:    proof.NodeSig1(),
×
3356
                                        Node2Signature:    proof.NodeSig2(),
×
3357
                                        Bitcoin1Signature: proof.BitcoinSig1(),
×
3358
                                        Bitcoin2Signature: proof.BitcoinSig2(),
×
3359
                                },
×
3360
                        )
×
3361

3362
                case gossipV2:
×
3363
                        res, err = db.AddV2ChannelProof(
×
3364
                                ctx, sqlc.AddV2ChannelProofParams{
×
3365
                                        Scid:      scidBytes,
×
3366
                                        Signature: proof.Sig(),
×
3367
                                },
×
3368
                        )
×
3369

3370
                default:
×
3371
                        return fmt.Errorf("unsupported gossip version: %d",
×
3372
                                proof.Version)
×
3373
                }
3374
                if err != nil {
×
3375
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
3376
                }
×
3377

3378
                n, err := res.RowsAffected()
×
3379
                if err != nil {
×
3380
                        return err
×
3381
                }
×
3382

3383
                if n == 0 {
×
3384
                        return fmt.Errorf("no rows affected when adding edge "+
×
3385
                                "proof for SCID %v", scid)
×
3386
                } else if n > 1 {
×
3387
                        return fmt.Errorf("multiple rows affected when adding "+
×
3388
                                "edge proof for SCID %v: %d rows affected",
×
3389
                                scid, n)
×
3390
                }
×
3391

3392
                return nil
×
3393
        }, sqldb.NoOpReset)
3394
        if err != nil {
×
3395
                return fmt.Errorf("unable to add edge proof: %w", err)
×
3396
        }
×
3397

3398
        return nil
×
3399
}
3400

3401
// PutClosedScid stores a SCID for a closed channel in the database. This is so
3402
// that we can ignore channel announcements that we know to be closed without
3403
// having to validate them and fetch a block.
3404
//
3405
// NOTE: part of the Store interface.
3406
func (s *SQLStore) PutClosedScid(ctx context.Context,
3407
        scid lnwire.ShortChannelID) error {
×
3408

×
3409
        var (
×
3410
                chanIDB = channelIDToBytes(scid.ToUint64())
×
3411
        )
×
3412

×
3413
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3414
                return db.InsertClosedChannel(ctx, chanIDB)
×
3415
        }, sqldb.NoOpReset)
×
3416
}
3417

3418
// IsClosedScid checks whether a channel identified by the passed in scid is
3419
// closed. This helps avoid having to perform expensive validation checks.
3420
//
3421
// NOTE: part of the Store interface.
3422
func (s *SQLStore) IsClosedScid(ctx context.Context,
3423
        scid lnwire.ShortChannelID) (bool, error) {
×
3424

×
3425
        var (
×
3426
                isClosed bool
×
3427
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
3428
        )
×
3429
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3430
                var err error
×
3431
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3432
                if err != nil {
×
3433
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3434
                                err)
×
3435
                }
×
3436

3437
                return nil
×
3438
        }, sqldb.NoOpReset)
3439
        if err != nil {
×
3440
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3441
                        err)
×
3442
        }
×
3443

3444
        return isClosed, nil
×
3445
}
3446

3447
// GraphSession will provide the call-back with access to a NodeTraverser
3448
// instance which can be used to perform queries against the channel graph.
3449
//
3450
// NOTE: part of the Store interface.
3451
func (s *SQLStore) GraphSession(ctx context.Context,
3452
        cb func(graph NodeTraverser) error, reset func()) error {
×
3453

×
3454
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3455
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3456
        }, reset)
×
3457
}
3458

3459
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3460
// read only transaction for a consistent view of the graph.
3461
type sqlNodeTraverser struct {
3462
        db    SQLQueries
3463
        chain chainhash.Hash
3464
}
3465

3466
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3467
// NodeTraverser interface.
3468
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3469

3470
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3471
func newSQLNodeTraverser(db SQLQueries,
3472
        chain chainhash.Hash) *sqlNodeTraverser {
×
3473

×
3474
        return &sqlNodeTraverser{
×
3475
                db:    db,
×
3476
                chain: chain,
×
3477
        }
×
3478
}
×
3479

3480
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3481
// node.
3482
//
3483
// NOTE: Part of the NodeTraverser interface.
3484
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(
3485
        ctx context.Context, nodePub route.Vertex,
3486
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3487

×
3488
        return forEachNodeDirectedChannel(
×
3489
                ctx, s.db, lnwire.GossipVersion1, nodePub, cb,
×
3490
        )
×
3491
}
×
3492

3493
// FetchNodeFeatures returns the features of the given node. If the node is
3494
// unknown, assume no additional features are supported.
3495
//
3496
// NOTE: Part of the NodeTraverser interface.
3497
func (s *sqlNodeTraverser) FetchNodeFeatures(ctx context.Context,
3498
        nodePub route.Vertex) (
3499
        *lnwire.FeatureVector, error) {
×
3500

×
3501
        return fetchNodeFeatures(ctx, s.db, lnwire.GossipVersion1, nodePub)
×
3502
}
×
3503

3504
// forEachNodeDirectedChannel iterates through all channels of a given
3505
// node, executing the passed callback on the directed edge representing the
3506
// channel and its incoming policy. If the node is not found, no error is
3507
// returned.
3508
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3509
        v lnwire.GossipVersion, nodePub route.Vertex,
3510
        cb func(channel *DirectedChannel) error) error {
×
3511

×
3512
        toNodeCallback := func() route.Vertex {
×
3513
                return nodePub
×
3514
        }
×
3515

3516
        dbID, err := db.GetNodeIDByPubKey(
×
3517
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3518
                        Version: int16(v),
×
3519
                        PubKey:  nodePub[:],
×
3520
                },
×
3521
        )
×
3522
        if errors.Is(err, sql.ErrNoRows) {
×
3523
                return nil
×
3524
        } else if err != nil {
×
3525
                return fmt.Errorf("unable to fetch node: %w", err)
×
3526
        }
×
3527

3528
        rows, err := db.ListChannelsByNodeID(
×
3529
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3530
                        Version: int16(v),
×
3531
                        NodeID1: dbID,
×
3532
                },
×
3533
        )
×
3534
        if err != nil {
×
3535
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3536
        }
×
3537

3538
        // Exit early if there are no channels for this node so we don't
3539
        // do the unnecessary feature fetching.
3540
        if len(rows) == 0 {
×
3541
                return nil
×
3542
        }
×
3543

3544
        features, err := getNodeFeatures(ctx, db, dbID)
×
3545
        if err != nil {
×
3546
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3547
        }
×
3548

3549
        for _, row := range rows {
×
3550
                node1, node2, err := buildNodeVertices(
×
3551
                        row.Node1Pubkey, row.Node2Pubkey,
×
3552
                )
×
3553
                if err != nil {
×
3554
                        return fmt.Errorf("unable to build node vertices: %w",
×
3555
                                err)
×
3556
                }
×
3557

3558
                edge := buildCacheableChannelInfo(
×
3559
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3560
                        node1, node2,
×
3561
                )
×
3562

×
3563
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3564
                if err != nil {
×
3565
                        return err
×
3566
                }
×
3567

3568
                p1, p2, err := buildCachedChanPolicies(
×
3569
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3570
                )
×
3571
                if err != nil {
×
3572
                        return err
×
3573
                }
×
3574

3575
                // Determine the outgoing and incoming policy for this
3576
                // channel and node combo.
3577
                outPolicy, inPolicy := p1, p2
×
3578
                if p1 != nil && node2 == nodePub {
×
3579
                        outPolicy, inPolicy = p2, p1
×
3580
                } else if p2 != nil && node1 != nodePub {
×
3581
                        outPolicy, inPolicy = p2, p1
×
3582
                }
×
3583

3584
                var cachedInPolicy *models.CachedEdgePolicy
×
3585
                if inPolicy != nil {
×
3586
                        cachedInPolicy = inPolicy
×
3587
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3588
                        cachedInPolicy.ToNodeFeatures = features
×
3589
                }
×
3590

3591
                directedChannel := &DirectedChannel{
×
3592
                        ChannelID:    edge.ChannelID,
×
3593
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3594
                        OtherNode:    edge.NodeKey2Bytes,
×
3595
                        Capacity:     edge.Capacity,
×
3596
                        OutPolicySet: outPolicy != nil,
×
3597
                        InPolicy:     cachedInPolicy,
×
3598
                }
×
3599
                if outPolicy != nil {
×
3600
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3601
                                directedChannel.InboundFee = fee
×
3602
                        })
×
3603
                }
3604

3605
                if nodePub == edge.NodeKey2Bytes {
×
3606
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3607
                }
×
3608

3609
                if err := cb(directedChannel); err != nil {
×
3610
                        return err
×
3611
                }
×
3612
        }
3613

3614
        return nil
×
3615
}
3616

3617
// forEachNodeCacheable fetches all node IDs and pub keys from the database,
3618
// and executes the provided callback for each node. It does so via pagination
3619
// along with batch loading of the node feature bits.
3620
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3621
        db SQLQueries, v lnwire.GossipVersion,
3622
        processNode func(nodeID int64, nodePub route.Vertex,
3623
                features *lnwire.FeatureVector) error) error {
×
3624

×
3625
        handleNode := func(_ context.Context,
×
3626
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3627
                featureBits map[int64][]int) error {
×
3628

×
3629
                fv := lnwire.EmptyFeatureVector()
×
3630
                if features, exists := featureBits[dbNode.ID]; exists {
×
3631
                        for _, bit := range features {
×
3632
                                fv.Set(lnwire.FeatureBit(bit))
×
3633
                        }
×
3634
                }
3635

3636
                var pub route.Vertex
×
3637
                copy(pub[:], dbNode.PubKey)
×
3638

×
3639
                return processNode(dbNode.ID, pub, fv)
×
3640
        }
3641

3642
        queryFunc := func(ctx context.Context, lastID int64,
×
3643
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3644

×
3645
                return db.ListNodeIDsAndPubKeys(
×
3646
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3647
                                Version: int16(v),
×
3648
                                ID:      lastID,
×
3649
                                Limit:   limit,
×
3650
                        },
×
3651
                )
×
3652
        }
×
3653

3654
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3655
                return row.ID
×
3656
        }
×
3657

3658
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3659
                return node.ID, nil
×
3660
        }
×
3661

3662
        batchQueryFunc := func(ctx context.Context,
×
3663
                nodeIDs []int64) (map[int64][]int, error) {
×
3664

×
3665
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3666
        }
×
3667

3668
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3669
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3670
                batchQueryFunc, handleNode,
×
3671
        )
×
3672
}
3673

3674
// forEachNodeChannel iterates through all channels of a node, executing
3675
// the passed callback on each. The call-back is provided with the channel's
3676
// edge information, the outgoing policy and the incoming policy for the
3677
// channel and node combo.
3678
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3679
        cfg *SQLStoreConfig, v lnwire.GossipVersion, id int64,
3680
        cb func(*models.ChannelEdgeInfo,
3681
                *models.ChannelEdgePolicy,
3682
                *models.ChannelEdgePolicy) error) error {
×
3683

×
3684
        // Get all the channels for this node.
×
3685
        rows, err := db.ListChannelsByNodeID(
×
3686
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3687
                        Version: int16(v),
×
3688
                        NodeID1: id,
×
3689
                },
×
3690
        )
×
3691
        if err != nil {
×
3692
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3693
        }
×
3694

3695
        // Collect all the channel and policy IDs.
3696
        var (
×
3697
                chanIDs   = make([]int64, 0, len(rows))
×
3698
                policyIDs = make([]int64, 0, 2*len(rows))
×
3699
        )
×
3700
        for _, row := range rows {
×
3701
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3702

×
3703
                if row.Policy1ID.Valid {
×
3704
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3705
                }
×
3706
                if row.Policy2ID.Valid {
×
3707
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3708
                }
×
3709
        }
3710

3711
        batchData, err := batchLoadChannelData(
×
3712
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3713
        )
×
3714
        if err != nil {
×
3715
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3716
        }
×
3717

3718
        // Call the call-back for each channel and its known policies.
3719
        for _, row := range rows {
×
3720
                node1, node2, err := buildNodeVertices(
×
3721
                        row.Node1Pubkey, row.Node2Pubkey,
×
3722
                )
×
3723
                if err != nil {
×
3724
                        return fmt.Errorf("unable to build node vertices: %w",
×
3725
                                err)
×
3726
                }
×
3727

3728
                edge, err := buildEdgeInfoWithBatchData(
×
3729
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3730
                        batchData,
×
3731
                )
×
3732
                if err != nil {
×
3733
                        return fmt.Errorf("unable to build channel info: %w",
×
3734
                                err)
×
3735
                }
×
3736

3737
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3738
                if err != nil {
×
3739
                        return fmt.Errorf("unable to extract channel "+
×
3740
                                "policies: %w", err)
×
3741
                }
×
3742

3743
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3744
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3745
                )
×
3746
                if err != nil {
×
3747
                        return fmt.Errorf("unable to build channel "+
×
3748
                                "policies: %w", err)
×
3749
                }
×
3750

3751
                // Determine the outgoing and incoming policy for this
3752
                // channel and node combo.
3753
                p1ToNode := row.GraphChannel.NodeID2
×
3754
                p2ToNode := row.GraphChannel.NodeID1
×
3755
                outPolicy, inPolicy := p1, p2
×
3756
                if (p1 != nil && p1ToNode == id) ||
×
3757
                        (p2 != nil && p2ToNode != id) {
×
3758

×
3759
                        outPolicy, inPolicy = p2, p1
×
3760
                }
×
3761

3762
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3763
                        return err
×
3764
                }
×
3765
        }
3766

3767
        return nil
×
3768
}
3769

3770
// updateChanEdgePolicy upserts the channel policy info we have stored for
3771
// a channel we already know of.
3772
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3773
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3774
        error) {
×
3775

×
3776
        var (
×
3777
                node1Pub, node2Pub route.Vertex
×
3778
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3779
                version            = edge.Version
×
3780
        )
×
3781

×
3782
        if !isKnownGossipVersion(version) {
×
3783
                return node1Pub, node2Pub, false, fmt.Errorf(
×
3784
                        "unsupported gossip version: %d", version,
×
3785
                )
×
3786
        }
×
3787

3788
        // Check that this edge policy refers to a channel that we already
3789
        // know of. We do this explicitly so that we can return the appropriate
3790
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
3791
        // abort the transaction which would abort the entire batch.
3792
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3793
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3794
                        Scid:    chanIDB,
×
3795
                        Version: int16(version),
×
3796
                },
×
3797
        )
×
3798
        if errors.Is(err, sql.ErrNoRows) {
×
3799
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3800
        } else if err != nil {
×
3801
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3802
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3803
        }
×
3804

3805
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3806
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3807

×
3808
        // Figure out which node this edge is from.
×
3809
        isNode1 := edge.IsNode1()
×
3810
        nodeID := dbChan.NodeID1
×
3811
        if !isNode1 {
×
3812
                nodeID = dbChan.NodeID2
×
3813
        }
×
3814

3815
        var (
×
3816
                inboundBase sql.NullInt64
×
3817
                inboundRate sql.NullInt64
×
3818
        )
×
3819
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3820
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3821
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3822
        })
×
3823

3824
        params := sqlc.UpsertEdgePolicyParams{
×
3825
                Version:                 int16(version),
×
3826
                ChannelID:               dbChan.ID,
×
3827
                NodeID:                  nodeID,
×
3828
                Timelock:                int32(edge.TimeLockDelta),
×
3829
                FeePpm:                  int64(edge.FeeProportionalMillionths),
×
3830
                BaseFeeMsat:             int64(edge.FeeBaseMSat),
×
3831
                MinHtlcMsat:             int64(edge.MinHTLC),
×
3832
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3833
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3834
                InboundBaseFeeMsat:      inboundBase,
×
3835
                InboundFeeRateMilliMsat: inboundRate,
×
3836
                Signature:               edge.SigBytes,
×
3837
        }
×
3838

×
3839
        switch version {
×
3840
        case gossipV1:
×
3841
                params.LastUpdate = sqldb.SQLInt64(edge.LastUpdate.Unix())
×
3842
                params.Disabled = sql.NullBool{
×
3843
                        Valid: true,
×
3844
                        Bool:  edge.IsDisabled(),
×
3845
                }
×
3846
                params.MaxHtlcMsat = sql.NullInt64{
×
3847
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3848
                        Int64: int64(edge.MaxHTLC),
×
3849
                }
×
3850
        case gossipV2:
×
3851
                params.BlockHeight = sqldb.SQLInt64(
×
3852
                        int64(edge.LastBlockHeight),
×
3853
                )
×
3854
                params.DisableFlags = sqldb.SQLInt16(edge.DisableFlags)
×
3855
                params.MaxHtlcMsat = sqldb.SQLInt64(int64(edge.MaxHTLC))
×
3856
        }
3857

3858
        id, err := tx.UpsertEdgePolicy(ctx, params)
×
3859
        if err != nil {
×
3860
                return node1Pub, node2Pub, isNode1,
×
3861
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3862
        }
×
3863

3864
        // Convert the flat extra opaque data into a map of TLV types to
3865
        // values.
3866
        extra := edge.ExtraSignedFields
×
3867
        if version == gossipV1 {
×
3868
                extra, err = marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3869
                if err != nil {
×
3870
                        return node1Pub, node2Pub, false, fmt.Errorf(
×
3871
                                "unable to marshal extra opaque data: %w", err,
×
3872
                        )
×
3873
                }
×
3874
        }
3875

3876
        // Update the channel policy's extra signed fields.
3877
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3878
        if err != nil {
×
3879
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3880
                        "policy extra TLVs: %w", err)
×
3881
        }
×
3882

3883
        return node1Pub, node2Pub, isNode1, nil
×
3884
}
3885

3886
// getNodeByPubKey attempts to look up a target node by its public key.
3887
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3888
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, *models.Node,
3889
        error) {
×
3890

×
3891
        dbNode, err := db.GetNodeByPubKey(
×
3892
                ctx, sqlc.GetNodeByPubKeyParams{
×
3893
                        Version: int16(v),
×
3894
                        PubKey:  pubKey[:],
×
3895
                },
×
3896
        )
×
3897
        if errors.Is(err, sql.ErrNoRows) {
×
3898
                return 0, nil, ErrGraphNodeNotFound
×
3899
        } else if err != nil {
×
3900
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3901
        }
×
3902

3903
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3904
        if err != nil {
×
3905
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3906
        }
×
3907

3908
        return dbNode.ID, node, nil
×
3909
}
3910

3911
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3912
// provided parameters.
3913
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3914
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3915

×
3916
        return &models.CachedEdgeInfo{
×
3917
                ChannelID:     byteOrder.Uint64(scid),
×
3918
                NodeKey1Bytes: node1Pub,
×
3919
                NodeKey2Bytes: node2Pub,
×
3920
                Capacity:      btcutil.Amount(capacity),
×
3921
        }
×
3922
}
×
3923

3924
// buildNode constructs a Node instance from the given database node
3925
// record. The node's features, addresses and extra signed fields are also
3926
// fetched from the database and set on the node.
3927
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3928
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3929

×
3930
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3931
        if err != nil {
×
3932
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3933
                        err)
×
3934
        }
×
3935

3936
        return buildNodeWithBatchData(dbNode, data)
×
3937
}
3938

3939
// isKnownGossipVersion checks whether the provided gossip version is known
3940
// and supported.
3941
func isKnownGossipVersion(v lnwire.GossipVersion) bool {
×
3942
        switch v {
×
3943
        case gossipV1:
×
3944
                return true
×
3945
        case gossipV2:
×
3946
                return true
×
3947
        default:
×
3948
                return false
×
3949
        }
3950
}
3951

3952
// buildNodeWithBatchData builds a models.Node instance
3953
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3954
// features/addresses/extra fields, then the corresponding fields are expected
3955
// to be present in the batchNodeData.
3956
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3957
        batchData *batchNodeData) (*models.Node, error) {
×
3958

×
3959
        v := lnwire.GossipVersion(dbNode.Version)
×
3960

×
3961
        if !isKnownGossipVersion(v) {
×
3962
                return nil, fmt.Errorf("unknown node version: %d", v)
×
3963
        }
×
3964

3965
        pub, err := route.NewVertexFromBytes(dbNode.PubKey)
×
3966
        if err != nil {
×
3967
                return nil, fmt.Errorf("unable to parse pubkey: %w", err)
×
3968
        }
×
3969

3970
        node := models.NewShellNode(v, pub)
×
3971

×
3972
        if len(dbNode.Signature) == 0 {
×
3973
                return node, nil
×
3974
        }
×
3975

3976
        node.AuthSigBytes = dbNode.Signature
×
3977

×
3978
        if dbNode.Alias.Valid {
×
3979
                node.Alias = fn.Some(dbNode.Alias.String)
×
3980
        }
×
3981
        if dbNode.LastUpdate.Valid {
×
3982
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3983
        }
×
3984
        if dbNode.BlockHeight.Valid {
×
3985
                node.LastBlockHeight = uint32(dbNode.BlockHeight.Int64)
×
3986
        }
×
3987

3988
        if dbNode.Color.Valid {
×
3989
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3990
                if err != nil {
×
3991
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3992
                                err)
×
3993
                }
×
3994

3995
                node.Color = fn.Some(nodeColor)
×
3996
        }
3997

3998
        // Use preloaded features.
3999
        if features, exists := batchData.features[dbNode.ID]; exists {
×
4000
                fv := lnwire.EmptyFeatureVector()
×
4001
                for _, bit := range features {
×
4002
                        fv.Set(lnwire.FeatureBit(bit))
×
4003
                }
×
4004
                node.Features = fv
×
4005
        }
4006

4007
        // Use preloaded addresses.
4008
        addresses, exists := batchData.addresses[dbNode.ID]
×
4009
        if exists && len(addresses) > 0 {
×
4010
                node.Addresses, err = buildNodeAddresses(addresses)
×
4011
                if err != nil {
×
4012
                        return nil, fmt.Errorf("unable to build addresses "+
×
4013
                                "for node(%d): %w", dbNode.ID, err)
×
4014
                }
×
4015
        }
4016

4017
        // Use preloaded extra fields.
4018
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
4019
                if v == gossipV1 {
×
4020
                        records := lnwire.CustomRecords(extraFields)
×
4021
                        recs, err := records.Serialize()
×
4022
                        if err != nil {
×
4023
                                return nil, fmt.Errorf("unable to serialize "+
×
4024
                                        "extra signed fields: %w", err)
×
4025
                        }
×
4026

4027
                        if len(recs) != 0 {
×
4028
                                node.ExtraOpaqueData = recs
×
4029
                        }
×
4030
                } else if len(extraFields) > 0 {
×
4031
                        node.ExtraSignedFields = extraFields
×
4032
                }
×
4033
        }
4034

4035
        return node, nil
×
4036
}
4037

4038
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
4039
// with the preloaded data, and executes the provided callback for each node.
4040
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
4041
        db SQLQueries, nodes []sqlc.GraphNode,
4042
        cb func(dbID int64, node *models.Node) error) error {
×
4043

×
4044
        // Extract node IDs for batch loading.
×
4045
        nodeIDs := make([]int64, len(nodes))
×
4046
        for i, node := range nodes {
×
4047
                nodeIDs[i] = node.ID
×
4048
        }
×
4049

4050
        // Batch load all related data for this page.
4051
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
4052
        if err != nil {
×
4053
                return fmt.Errorf("unable to batch load node data: %w", err)
×
4054
        }
×
4055

4056
        for _, dbNode := range nodes {
×
4057
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
4058
                if err != nil {
×
4059
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
4060
                                dbNode.ID, err)
×
4061
                }
×
4062

4063
                if err := cb(dbNode.ID, node); err != nil {
×
4064
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
4065
                                dbNode.ID, err)
×
4066
                }
×
4067
        }
4068

4069
        return nil
×
4070
}
4071

4072
// getNodeFeatures fetches the feature bits and constructs the feature vector
4073
// for a node with the given DB ID.
4074
func getNodeFeatures(ctx context.Context, db SQLQueries,
4075
        nodeID int64) (*lnwire.FeatureVector, error) {
×
4076

×
4077
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
4078
        if err != nil {
×
4079
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
4080
                        nodeID, err)
×
4081
        }
×
4082

4083
        features := lnwire.EmptyFeatureVector()
×
4084
        for _, feature := range rows {
×
4085
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
4086
        }
×
4087

4088
        return features, nil
×
4089
}
4090

4091
// upsertNodeAncillaryData updates the node's features, addresses, and extra
4092
// signed fields. This is common logic shared by upsertNode and
4093
// upsertSourceNode.
4094
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
4095
        nodeID int64, node *models.Node) error {
×
4096

×
4097
        // Update the node's features.
×
4098
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
4099
        if err != nil {
×
4100
                return fmt.Errorf("inserting node features: %w", err)
×
4101
        }
×
4102

4103
        // Update the node's addresses.
4104
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
4105
        if err != nil {
×
4106
                return fmt.Errorf("inserting node addresses: %w", err)
×
4107
        }
×
4108

4109
        // Convert the flat extra opaque data into a map of TLV types to
4110
        // values.
4111
        extra := node.ExtraSignedFields
×
4112
        if node.Version == gossipV1 {
×
4113
                extra, err = marshalExtraOpaqueData(node.ExtraOpaqueData)
×
4114
                if err != nil {
×
4115
                        return fmt.Errorf("unable to marshal extra opaque "+
×
4116
                                "data: %w", err)
×
4117
                }
×
4118
        }
4119

4120
        // Update the node's extra signed fields.
4121
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
4122
        if err != nil {
×
4123
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
4124
        }
×
4125

4126
        return nil
×
4127
}
4128

4129
// populateNodeParams populates the common node parameters from a models.Node.
4130
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
4131
func populateNodeParams(node *models.Node,
4132
        setParams func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
4133
                colorStr sql.NullString, signature []byte)) error {
×
4134

×
4135
        if !node.HaveAnnouncement() {
×
4136
                return nil
×
4137
        }
×
4138

4139
        var (
×
4140
                alias, colorStr             sql.NullString
×
4141
                lastUpdate, lastBlockHeight sql.NullInt64
×
4142
        )
×
4143
        node.Color.WhenSome(func(rgba color.RGBA) {
×
4144
                colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
4145
        })
×
4146
        node.Alias.WhenSome(func(s string) {
×
4147
                alias = sqldb.SQLStrValid(s)
×
4148
        })
×
4149

4150
        switch node.Version {
×
4151
        case gossipV1:
×
4152
                lastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
4153

4154
        case gossipV2:
×
4155
                lastBlockHeight = sqldb.SQLInt64(int64(node.LastBlockHeight))
×
4156

4157
        default:
×
4158
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
4159
        }
4160

4161
        setParams(
×
4162
                lastUpdate, lastBlockHeight, alias, colorStr, node.AuthSigBytes,
×
4163
        )
×
4164

×
4165
        return nil
×
4166
}
4167

4168
// buildNodeUpsertParams builds the parameters for upserting a node using the
4169
// strict UpsertNode query (requires timestamp to be increasing).
4170
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
4171
        params := sqlc.UpsertNodeParams{
×
4172
                Version: int16(node.Version),
×
4173
                PubKey:  node.PubKeyBytes[:],
×
4174
        }
×
4175

×
4176
        err := populateNodeParams(
×
4177
                node, func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
×
4178
                        colorStr sql.NullString,
×
4179
                        signature []byte) {
×
4180

×
4181
                        params.LastUpdate = lastUpdate
×
4182
                        params.BlockHeight = lastBlockHeight
×
4183
                        params.Alias = alias
×
4184
                        params.Color = colorStr
×
4185
                        params.Signature = signature
×
4186
                },
×
4187
        )
4188

4189
        return params, err
×
4190
}
4191

4192
// buildSourceNodeUpsertParams builds the parameters for upserting the source
4193
// node using the lenient UpsertSourceNode query (allows same timestamp).
4194
func buildSourceNodeUpsertParams(node *models.Node) (
4195
        sqlc.UpsertSourceNodeParams, error) {
×
4196

×
4197
        params := sqlc.UpsertSourceNodeParams{
×
4198
                Version: int16(node.Version),
×
4199
                PubKey:  node.PubKeyBytes[:],
×
4200
        }
×
4201

×
4202
        err := populateNodeParams(
×
4203
                node, func(lastUpdate, lastBlock sql.NullInt64, alias,
×
4204
                        colorStr sql.NullString, signature []byte) {
×
4205

×
4206
                        params.BlockHeight = lastBlock
×
4207
                        params.LastUpdate = lastUpdate
×
4208
                        params.Alias = alias
×
4209
                        params.Color = colorStr
×
4210
                        params.Signature = signature
×
4211
                },
×
4212
        )
4213

4214
        return params, err
×
4215
}
4216

4217
// upsertSourceNode upserts the source node record into the database using a
4218
// less strict upsert that allows updates even when the timestamp hasn't
4219
// changed. This is necessary to handle concurrent updates to our own node
4220
// during startup and runtime. The node's features, addresses and extra TLV
4221
// types are also updated. The node's DB ID is returned.
4222
func upsertSourceNode(ctx context.Context, db SQLQueries,
4223
        node *models.Node) (int64, error) {
×
4224

×
4225
        params, err := buildSourceNodeUpsertParams(node)
×
4226
        if err != nil {
×
4227
                return 0, err
×
4228
        }
×
4229

4230
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
4231
        if err != nil {
×
4232
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
4233
                        node.PubKeyBytes, err)
×
4234
        }
×
4235

4236
        // We can exit here if we don't have the announcement yet.
4237
        if !node.HaveAnnouncement() {
×
4238
                return nodeID, nil
×
4239
        }
×
4240

4241
        // Update the ancillary node data (features, addresses, extra fields).
4242
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
4243
        if err != nil {
×
4244
                return 0, err
×
4245
        }
×
4246

4247
        return nodeID, nil
×
4248
}
4249

4250
// upsertNode upserts the node record into the database. If the node already
4251
// exists, then the node's information is updated. If the node doesn't exist,
4252
// then a new node is created. The node's features, addresses and extra TLV
4253
// types are also updated. The node's DB ID is returned.
4254
func upsertNode(ctx context.Context, db SQLQueries,
4255
        node *models.Node) (int64, error) {
×
4256

×
4257
        if !isKnownGossipVersion(node.Version) {
×
4258
                return 0, fmt.Errorf("unknown gossip version: %d", node.Version)
×
4259
        }
×
4260

4261
        params, err := buildNodeUpsertParams(node)
×
4262
        if err != nil {
×
4263
                return 0, err
×
4264
        }
×
4265

4266
        nodeID, err := db.UpsertNode(ctx, params)
×
4267
        if err != nil {
×
4268
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
4269
                        err)
×
4270
        }
×
4271

4272
        // We can exit here if we don't have the announcement yet.
4273
        if !node.HaveAnnouncement() {
×
4274
                return nodeID, nil
×
4275
        }
×
4276

4277
        // Update the ancillary node data (features, addresses, extra fields).
4278
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
4279
        if err != nil {
×
4280
                return 0, err
×
4281
        }
×
4282

4283
        return nodeID, nil
×
4284
}
4285

4286
// upsertNodeFeatures updates the node's features node_features table. This
4287
// includes deleting any feature bits no longer present and inserting any new
4288
// feature bits. If the feature bit does not yet exist in the features table,
4289
// then an entry is created in that table first.
4290
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
4291
        features *lnwire.FeatureVector) error {
×
4292

×
4293
        // Get any existing features for the node.
×
4294
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
4295
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
4296
                return err
×
4297
        }
×
4298

4299
        // Copy the nodes latest set of feature bits.
4300
        newFeatures := make(map[int32]struct{})
×
4301
        if features != nil {
×
4302
                for feature := range features.Features() {
×
4303
                        newFeatures[int32(feature)] = struct{}{}
×
4304
                }
×
4305
        }
4306

4307
        // For any current feature that already exists in the DB, remove it from
4308
        // the in-memory map. For any existing feature that does not exist in
4309
        // the in-memory map, delete it from the database.
4310
        for _, feature := range existingFeatures {
×
4311
                // The feature is still present, so there are no updates to be
×
4312
                // made.
×
4313
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
4314
                        delete(newFeatures, feature.FeatureBit)
×
4315
                        continue
×
4316
                }
4317

4318
                // The feature is no longer present, so we remove it from the
4319
                // database.
4320
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
4321
                        NodeID:     nodeID,
×
4322
                        FeatureBit: feature.FeatureBit,
×
4323
                })
×
4324
                if err != nil {
×
4325
                        return fmt.Errorf("unable to delete node(%d) "+
×
4326
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
4327
                                err)
×
4328
                }
×
4329
        }
4330

4331
        // Any remaining entries in newFeatures are new features that need to be
4332
        // added to the database for the first time.
4333
        for feature := range newFeatures {
×
4334
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
4335
                        NodeID:     nodeID,
×
4336
                        FeatureBit: feature,
×
4337
                })
×
4338
                if err != nil {
×
4339
                        return fmt.Errorf("unable to insert node(%d) "+
×
4340
                                "feature(%v): %w", nodeID, feature, err)
×
4341
                }
×
4342
        }
4343

4344
        return nil
×
4345
}
4346

4347
// fetchNodeFeatures fetches the features for a node with the given public key.
4348
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
4349
        v lnwire.GossipVersion, nodePub route.Vertex) (*lnwire.FeatureVector,
4350
        error) {
×
4351

×
4352
        rows, err := queries.GetNodeFeaturesByPubKey(
×
4353
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
4354
                        PubKey:  nodePub[:],
×
4355
                        Version: int16(v),
×
4356
                },
×
4357
        )
×
4358
        if err != nil {
×
4359
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
4360
                        nodePub, err)
×
4361
        }
×
4362

4363
        features := lnwire.EmptyFeatureVector()
×
4364
        for _, bit := range rows {
×
4365
                features.Set(lnwire.FeatureBit(bit))
×
4366
        }
×
4367

4368
        return features, nil
×
4369
}
4370

4371
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type of a row determines how its address string is serialized
// and deserialized. The numeric values are written to the database, so they
// are spelled out explicitly and must never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose IP has a 4-byte form.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks any other TCP address with a 16-byte IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks an onion address of tor.V2Len length.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks an onion address of tor.V3Len length.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks a DNS host name address.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks an address kept in its opaque encoded form.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
4384

4385
// collectAddressRecords collects the addresses from the provided
4386
// net.Addr slice and returns a map of dbAddressType to a slice of address
4387
// strings.
4388
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
4389
        error) {
×
4390

×
4391
        // Copy the nodes latest set of addresses.
×
4392
        newAddresses := map[dbAddressType][]string{
×
4393
                addressTypeIPv4:   {},
×
4394
                addressTypeIPv6:   {},
×
4395
                addressTypeTorV2:  {},
×
4396
                addressTypeTorV3:  {},
×
4397
                addressTypeDNS:    {},
×
4398
                addressTypeOpaque: {},
×
4399
        }
×
4400
        addAddr := func(t dbAddressType, addr net.Addr) {
×
4401
                newAddresses[t] = append(newAddresses[t], addr.String())
×
4402
        }
×
4403

4404
        for _, address := range addresses {
×
4405
                switch addr := address.(type) {
×
4406
                case *net.TCPAddr:
×
4407
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
4408
                                addAddr(addressTypeIPv4, addr)
×
4409
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
4410
                                addAddr(addressTypeIPv6, addr)
×
4411
                        } else {
×
4412
                                return nil, fmt.Errorf("unhandled IP "+
×
4413
                                        "address: %v", addr)
×
4414
                        }
×
4415

4416
                case *tor.OnionAddr:
×
4417
                        switch len(addr.OnionService) {
×
4418
                        case tor.V2Len:
×
4419
                                addAddr(addressTypeTorV2, addr)
×
4420
                        case tor.V3Len:
×
4421
                                addAddr(addressTypeTorV3, addr)
×
4422
                        default:
×
4423
                                return nil, fmt.Errorf("invalid length for " +
×
4424
                                        "a tor address")
×
4425
                        }
4426

4427
                case *lnwire.DNSAddress:
×
4428
                        addAddr(addressTypeDNS, addr)
×
4429

4430
                case *lnwire.OpaqueAddrs:
×
4431
                        addAddr(addressTypeOpaque, addr)
×
4432

4433
                default:
×
4434
                        return nil, fmt.Errorf("unhandled address type: %T",
×
4435
                                addr)
×
4436
                }
4437
        }
4438

4439
        return newAddresses, nil
×
4440
}
4441

4442
// upsertNodeAddresses updates the node's addresses in the database. This
4443
// includes deleting any existing addresses and inserting the new set of
4444
// addresses. The deletion is necessary since the ordering of the addresses may
4445
// change, and we need to ensure that the database reflects the latest set of
4446
// addresses so that at the time of reconstructing the node announcement, the
4447
// order is preserved and the signature over the message remains valid.
4448
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
4449
        addresses []net.Addr) error {
×
4450

×
4451
        // Delete any existing addresses for the node. This is required since
×
4452
        // even if the new set of addresses is the same, the ordering may have
×
4453
        // changed for a given address type.
×
4454
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
4455
        if err != nil {
×
4456
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
4457
                        nodeID, err)
×
4458
        }
×
4459

4460
        newAddresses, err := collectAddressRecords(addresses)
×
4461
        if err != nil {
×
4462
                return err
×
4463
        }
×
4464

4465
        // Any remaining entries in newAddresses are new addresses that need to
4466
        // be added to the database for the first time.
4467
        for addrType, addrList := range newAddresses {
×
4468
                for position, addr := range addrList {
×
4469
                        err := db.UpsertNodeAddress(
×
4470
                                ctx, sqlc.UpsertNodeAddressParams{
×
4471
                                        NodeID:   nodeID,
×
4472
                                        Type:     int16(addrType),
×
4473
                                        Address:  addr,
×
4474
                                        Position: int32(position),
×
4475
                                },
×
4476
                        )
×
4477
                        if err != nil {
×
4478
                                return fmt.Errorf("unable to insert "+
×
4479
                                        "node(%d) address(%v): %w", nodeID,
×
4480
                                        addr, err)
×
4481
                        }
×
4482
                }
4483
        }
4484

4485
        return nil
×
4486
}
4487

4488
// getNodeAddresses fetches the addresses for a node with the given DB ID.
4489
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
4490
        error) {
×
4491

×
4492
        // GetNodeAddresses ensures that the addresses for a given type are
×
4493
        // returned in the same order as they were inserted.
×
4494
        rows, err := db.GetNodeAddresses(ctx, id)
×
4495
        if err != nil {
×
4496
                return nil, err
×
4497
        }
×
4498

4499
        addresses := make([]net.Addr, 0, len(rows))
×
4500
        for _, row := range rows {
×
4501
                address := row.Address
×
4502

×
4503
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4504
                if err != nil {
×
4505
                        return nil, fmt.Errorf("unable to parse address "+
×
4506
                                "for node(%d): %v: %w", id, address, err)
×
4507
                }
×
4508

4509
                addresses = append(addresses, addr)
×
4510
        }
4511

4512
        // If we have no addresses, then we'll return nil instead of an
4513
        // empty slice.
4514
        if len(addresses) == 0 {
×
4515
                addresses = nil
×
4516
        }
×
4517

4518
        return addresses, nil
×
4519
}
4520

4521
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4522
// database. This includes updating any existing types, inserting any new types,
4523
// and deleting any types that are no longer present.
4524
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4525
        nodeID int64, extraFields map[uint64][]byte) error {
×
4526

×
4527
        // Get any existing extra signed fields for the node.
×
4528
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4529
        if err != nil {
×
4530
                return err
×
4531
        }
×
4532

4533
        // Make a lookup map of the existing field types so that we can use it
4534
        // to keep track of any fields we should delete.
4535
        m := make(map[uint64]bool)
×
4536
        for _, field := range existingFields {
×
4537
                m[uint64(field.Type)] = true
×
4538
        }
×
4539

4540
        // For all the new fields, we'll upsert them and remove them from the
4541
        // map of existing fields.
4542
        for tlvType, value := range extraFields {
×
4543
                err = db.UpsertNodeExtraType(
×
4544
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4545
                                NodeID: nodeID,
×
4546
                                Type:   int64(tlvType),
×
4547
                                Value:  value,
×
4548
                        },
×
4549
                )
×
4550
                if err != nil {
×
4551
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4552
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4553
                }
×
4554

4555
                // Remove the field from the map of existing fields if it was
4556
                // present.
4557
                delete(m, tlvType)
×
4558
        }
4559

4560
        // For all the fields that are left in the map of existing fields, we'll
4561
        // delete them as they are no longer present in the new set of fields.
4562
        for tlvType := range m {
×
4563
                err = db.DeleteExtraNodeType(
×
4564
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4565
                                NodeID: nodeID,
×
4566
                                Type:   int64(tlvType),
×
4567
                        },
×
4568
                )
×
4569
                if err != nil {
×
4570
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4571
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4572
                }
×
4573
        }
4574

4575
        return nil
×
4576
}
4577

4578
// srcNodeInfo holds the information about the source node of the graph.
4579
type srcNodeInfo struct {
4580
        // id is the DB level ID of the source node entry in the "nodes" table.
4581
        id int64
4582

4583
        // pub is the public key of the source node.
4584
        pub route.Vertex
4585
}
4586

4587
// sourceNode returns the DB node ID and pub key of the source node for the
4588
// specified protocol version.
4589
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4590
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4591

×
4592
        s.srcNodeMu.Lock()
×
4593
        defer s.srcNodeMu.Unlock()
×
4594

×
4595
        // If we already have the source node ID and pub key cached, then
×
4596
        // return them.
×
4597
        if info, ok := s.srcNodes[version]; ok {
×
4598
                return info.id, info.pub, nil
×
4599
        }
×
4600

4601
        var pubKey route.Vertex
×
4602

×
4603
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4604
        if err != nil {
×
4605
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4606
                        err)
×
4607
        }
×
4608

4609
        if len(nodes) == 0 {
×
4610
                return 0, pubKey, ErrSourceNodeNotSet
×
4611
        } else if len(nodes) > 1 {
×
4612
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4613
                        "protocol %s found", version)
×
4614
        }
×
4615

4616
        copy(pubKey[:], nodes[0].PubKey)
×
4617

×
4618
        s.srcNodes[version] = &srcNodeInfo{
×
4619
                id:  nodes[0].NodeID,
×
4620
                pub: pubKey,
×
4621
        }
×
4622

×
4623
        return nodes[0].NodeID, pubKey, nil
×
4624
}
4625

4626
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4627
// This then produces a map from TLV type to value. If the input is not a
4628
// valid TLV stream, then an error is returned.
4629
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4630
        r := bytes.NewReader(data)
×
4631

×
4632
        tlvStream, err := tlv.NewStream()
×
4633
        if err != nil {
×
4634
                return nil, err
×
4635
        }
×
4636

4637
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4638
        // pass it into the P2P decoding variant.
4639
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4640
        if err != nil {
×
4641
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4642
        }
×
4643
        if len(parsedTypes) == 0 {
×
4644
                return nil, nil
×
4645
        }
×
4646

4647
        records := make(map[uint64][]byte)
×
4648
        for k, v := range parsedTypes {
×
4649
                records[uint64(k)] = v
×
4650
        }
×
4651

4652
        return records, nil
×
4653
}
4654

4655
// insertChannel inserts a new channel record into the database.
4656
func insertChannel(ctx context.Context, db SQLQueries,
4657
        edge *models.ChannelEdgeInfo) error {
×
4658

×
4659
        v := edge.Version
×
4660

×
4661
        // Make sure that at least a "shell" entry for each node is present in
×
4662
        // the nodes table.
×
4663
        node1DBID, err := maybeCreateShellNode(ctx, db, v, edge.NodeKey1Bytes)
×
4664
        if err != nil {
×
4665
                return fmt.Errorf("unable to create shell node: %w", err)
×
4666
        }
×
4667

4668
        node2DBID, err := maybeCreateShellNode(ctx, db, v, edge.NodeKey2Bytes)
×
4669
        if err != nil {
×
4670
                return fmt.Errorf("unable to create shell node: %w", err)
×
4671
        }
×
4672

4673
        var capacity sql.NullInt64
×
4674
        if edge.Capacity != 0 {
×
4675
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4676
        }
×
4677

4678
        createParams := sqlc.CreateChannelParams{
×
4679
                Version:  int16(v),
×
4680
                Scid:     channelIDToBytes(edge.ChannelID),
×
4681
                NodeID1:  node1DBID,
×
4682
                NodeID2:  node2DBID,
×
4683
                Outpoint: edge.ChannelPoint.String(),
×
4684
                Capacity: capacity,
×
4685
        }
×
4686
        edge.BitcoinKey1Bytes.WhenSome(func(vertex route.Vertex) {
×
4687
                createParams.BitcoinKey1 = vertex[:]
×
4688
        })
×
4689
        edge.BitcoinKey2Bytes.WhenSome(func(vertex route.Vertex) {
×
4690
                createParams.BitcoinKey2 = vertex[:]
×
4691
        })
×
4692
        edge.FundingScript.WhenSome(func(script []byte) {
×
4693
                createParams.FundingPkScript = script
×
4694
        })
×
4695
        edge.MerkleRootHash.WhenSome(func(hash chainhash.Hash) {
×
4696
                createParams.MerkleRootHash = hash[:]
×
4697
        })
×
4698

4699
        if edge.AuthProof != nil {
×
4700
                proof := edge.AuthProof
×
4701

×
4702
                createParams.Node1Signature = proof.NodeSig1()
×
4703
                createParams.Node2Signature = proof.NodeSig2()
×
4704
                createParams.Bitcoin1Signature = proof.BitcoinSig1()
×
4705
                createParams.Bitcoin2Signature = proof.BitcoinSig2()
×
4706
                createParams.Signature = proof.Sig()
×
4707
        }
×
4708

4709
        // Insert the new channel record.
4710
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4711
        if err != nil {
×
4712
                return err
×
4713
        }
×
4714

4715
        // Insert any channel features.
4716
        for feature := range edge.Features.Features() {
×
4717
                err = db.InsertChannelFeature(
×
4718
                        ctx, sqlc.InsertChannelFeatureParams{
×
4719
                                ChannelID:  dbChanID,
×
4720
                                FeatureBit: int32(feature),
×
4721
                        },
×
4722
                )
×
4723
                if err != nil {
×
4724
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4725
                                "feature(%v): %w", dbChanID, feature, err)
×
4726
                }
×
4727
        }
4728

4729
        // Finally, insert any extra TLV fields in the channel announcement.
4730
        extra := edge.ExtraSignedFields
×
4731
        if v == gossipV1 {
×
4732
                extra, err = marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4733
                if err != nil {
×
4734
                        return fmt.Errorf("unable to marshal extra opaque "+
×
4735
                                "data: %w", err)
×
4736
                }
×
4737
        }
4738

4739
        for tlvType, value := range extra {
×
4740
                err := db.UpsertChannelExtraType(
×
4741
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4742
                                ChannelID: dbChanID,
×
4743
                                Type:      int64(tlvType),
×
4744
                                Value:     value,
×
4745
                        },
×
4746
                )
×
4747
                if err != nil {
×
4748
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4749
                                "extra signed field(%v): %w", edge.ChannelID,
×
4750
                                tlvType, err)
×
4751
                }
×
4752
        }
4753

4754
        return nil
×
4755
}
4756

4757
// maybeCreateShellNode checks if a shell node entry exists for the
4758
// given public key. If it does not exist, then a new shell node entry is
4759
// created. The ID of the node is returned. A shell node only has a protocol
4760
// version and public key persisted.
4761
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4762
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, error) {
×
4763

×
4764
        dbNode, err := db.GetNodeByPubKey(
×
4765
                ctx, sqlc.GetNodeByPubKeyParams{
×
4766
                        PubKey:  pubKey[:],
×
4767
                        Version: int16(v),
×
4768
                },
×
4769
        )
×
4770
        // The node exists. Return the ID.
×
4771
        if err == nil {
×
4772
                return dbNode.ID, nil
×
4773
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4774
                return 0, err
×
4775
        }
×
4776

4777
        // Otherwise, the node does not exist, so we create a shell entry for
4778
        // it.
4779
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4780
                Version: int16(v),
×
4781
                PubKey:  pubKey[:],
×
4782
        })
×
4783
        if err != nil {
×
4784
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4785
        }
×
4786

4787
        return id, nil
×
4788
}
4789

4790
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4791
// the database. This includes deleting any existing types and then inserting
4792
// the new types.
4793
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4794
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4795

×
4796
        // Delete all existing extra signed fields for the channel policy.
×
4797
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4798
        if err != nil {
×
4799
                return fmt.Errorf("unable to delete "+
×
4800
                        "existing policy extra signed fields for policy %d: %w",
×
4801
                        chanPolicyID, err)
×
4802
        }
×
4803

4804
        // Insert all new extra signed fields for the channel policy.
4805
        for tlvType, value := range extraFields {
×
4806
                err = db.UpsertChanPolicyExtraType(
×
4807
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4808
                                ChannelPolicyID: chanPolicyID,
×
4809
                                Type:            int64(tlvType),
×
4810
                                Value:           value,
×
4811
                        },
×
4812
                )
×
4813
                if err != nil {
×
4814
                        return fmt.Errorf("unable to insert "+
×
4815
                                "channel_policy(%d) extra signed field(%v): %w",
×
4816
                                chanPolicyID, tlvType, err)
×
4817
                }
×
4818
        }
4819

4820
        return nil
×
4821
}
4822

4823
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4824
// provided dbChanRow and also fetches any other required information
4825
// to construct the edge info.
4826
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4827
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4828
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4829

×
4830
        data, err := batchLoadChannelData(
×
4831
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4832
        )
×
4833
        if err != nil {
×
4834
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4835
                        err)
×
4836
        }
×
4837

4838
        return buildEdgeInfoWithBatchData(
×
4839
                cfg.ChainHash, dbChan, node1, node2, data,
×
4840
        )
×
4841
}
4842

4843
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4844
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4845
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4846
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4847

×
4848
        v := lnwire.GossipVersion(dbChan.Version)
×
4849
        if !isKnownGossipVersion(v) {
×
4850
                return nil, fmt.Errorf("unknown channel version: %d", v)
×
4851
        }
×
4852

4853
        // Use pre-loaded features and extras types.
4854
        fv := lnwire.EmptyFeatureVector()
×
4855
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4856
                for _, bit := range features {
×
4857
                        fv.Set(lnwire.FeatureBit(bit))
×
4858
                }
×
4859
        }
4860

4861
        var extras map[uint64][]byte
×
4862
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4863
        if exists {
×
4864
                extras = channelExtras
×
4865
        } else {
×
4866
                extras = make(map[uint64][]byte)
×
4867
        }
×
4868

4869
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4870
        if err != nil {
×
4871
                return nil, err
×
4872
        }
×
4873

4874
        // Build the appropriate channel based on version.
4875
        var channel *models.ChannelEdgeInfo
×
4876
        switch v {
×
4877
        case gossipV1:
×
4878
                // For v1, serialize extras into ExtraOpaqueData.
×
4879
                recs, err := lnwire.CustomRecords(extras).Serialize()
×
4880
                if err != nil {
×
4881
                        return nil, fmt.Errorf("unable to serialize extra "+
×
4882
                                "signed fields: %w", err)
×
4883
                }
×
4884
                if recs == nil {
×
4885
                        recs = make([]byte, 0)
×
4886
                }
×
4887

4888
                // Bitcoin keys are required for v1.
4889
                btcKey1, err := route.NewVertexFromBytes(dbChan.BitcoinKey1)
×
4890
                if err != nil {
×
4891
                        return nil, err
×
4892
                }
×
4893
                btcKey2, err := route.NewVertexFromBytes(dbChan.BitcoinKey2)
×
4894
                if err != nil {
×
4895
                        return nil, err
×
4896
                }
×
4897

4898
                channel, err = models.NewV1Channel(
×
4899
                        byteOrder.Uint64(dbChan.Scid), chain, node1, node2,
×
4900
                        &models.ChannelV1Fields{
×
4901
                                BitcoinKey1Bytes: btcKey1,
×
4902
                                BitcoinKey2Bytes: btcKey2,
×
4903
                                ExtraOpaqueData:  recs,
×
4904
                        },
×
4905
                        models.WithChannelPoint(*op),
×
4906
                        models.WithCapacity(
×
4907
                                btcutil.Amount(dbChan.Capacity.Int64),
×
4908
                        ),
×
4909
                        models.WithFeatures(fv.RawFeatureVector),
×
4910
                )
×
4911
                if err != nil {
×
4912
                        return nil, err
×
4913
                }
×
4914

4915
                // For v1 channels, attach the auth proof if all four
4916
                // signatures are present.
4917
                if len(dbChan.Bitcoin1Signature) > 0 {
×
4918
                        channel.AuthProof = models.NewV1ChannelAuthProof(
×
4919
                                dbChan.Node1Signature,
×
4920
                                dbChan.Node2Signature,
×
4921
                                dbChan.Bitcoin1Signature,
×
4922
                                dbChan.Bitcoin2Signature,
×
4923
                        )
×
4924
                }
×
4925

4926
        case gossipV2:
×
4927
                v2Fields := &models.ChannelV2Fields{
×
4928
                        ExtraSignedFields: extras,
×
4929
                }
×
4930

×
4931
                // For v2, bitcoin keys are optional.
×
4932
                if len(dbChan.BitcoinKey1) > 0 {
×
4933
                        btcKey1, err := route.NewVertexFromBytes(
×
4934
                                dbChan.BitcoinKey1,
×
4935
                        )
×
4936
                        if err != nil {
×
4937
                                return nil, err
×
4938
                        }
×
4939
                        v2Fields.BitcoinKey1Bytes = fn.Some(btcKey1)
×
4940
                }
4941
                if len(dbChan.BitcoinKey2) > 0 {
×
4942
                        btcKey2, err := route.NewVertexFromBytes(
×
4943
                                dbChan.BitcoinKey2,
×
4944
                        )
×
4945
                        if err != nil {
×
4946
                                return nil, err
×
4947
                        }
×
4948
                        v2Fields.BitcoinKey2Bytes = fn.Some(btcKey2)
×
4949
                }
4950

4951
                // Parse funding script if present.
4952
                if len(dbChan.FundingPkScript) > 0 {
×
4953
                        v2Fields.FundingScript = fn.Some(dbChan.FundingPkScript)
×
4954
                }
×
4955

4956
                // Parse merkle root hash if present.
4957
                if len(dbChan.MerkleRootHash) > 0 {
×
4958
                        var hash chainhash.Hash
×
4959
                        copy(hash[:], dbChan.MerkleRootHash)
×
4960
                        v2Fields.MerkleRootHash = fn.Some(hash)
×
4961
                }
×
4962

4963
                opts := []models.EdgeModifier{
×
4964
                        models.WithChannelPoint(*op),
×
4965
                        models.WithCapacity(btcutil.Amount(
×
4966
                                dbChan.Capacity.Int64,
×
4967
                        )),
×
4968
                        models.WithFeatures(fv.RawFeatureVector),
×
4969
                }
×
4970

×
4971
                // For v2 channels, attach the auth proof if the signature is
×
4972
                // present.
×
4973
                if len(dbChan.Signature) > 0 {
×
4974
                        proof := models.NewV2ChannelAuthProof(dbChan.Signature)
×
4975
                        opts = append(opts, models.WithChanProof(proof))
×
4976
                }
×
4977

4978
                channel, err = models.NewV2Channel(
×
4979
                        byteOrder.Uint64(dbChan.Scid), chain, node1, node2,
×
4980
                        v2Fields, opts...,
×
4981
                )
×
4982
                if err != nil {
×
4983
                        return nil, err
×
4984
                }
×
4985

4986
        default:
×
4987
                return nil, fmt.Errorf("unsupported channel version: %d", v)
×
4988
        }
4989

4990
        return channel, nil
×
4991
}
4992

4993
// buildNodeVertices is a helper that converts raw node public keys
4994
// into route.Vertex instances.
4995
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4996
        route.Vertex, error) {
×
4997

×
4998
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4999
        if err != nil {
×
5000
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
5001
                        "create vertex from node1 pubkey: %w", err)
×
5002
        }
×
5003

5004
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
5005
        if err != nil {
×
5006
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
5007
                        "create vertex from node2 pubkey: %w", err)
×
5008
        }
×
5009

5010
        return node1Vertex, node2Vertex, nil
×
5011
}
5012

5013
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
5014
// retrieves all the extra info required to build the complete
5015
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
5016
// the provided sqlc.GraphChannelPolicy records are nil.
5017
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
5018
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5019
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
5020
        *models.ChannelEdgePolicy, error) {
×
5021

×
5022
        if dbPol1 == nil && dbPol2 == nil {
×
5023
                return nil, nil, nil
×
5024
        }
×
5025

5026
        if dbPol1 != nil &&
×
5027
                !isKnownGossipVersion(lnwire.GossipVersion(dbPol1.Version)) {
×
5028

×
5029
                return nil, nil, fmt.Errorf("unsupported policy1 version: %d",
×
5030
                        dbPol1.Version)
×
5031
        }
×
5032

5033
        if dbPol2 != nil &&
×
5034
                !isKnownGossipVersion(lnwire.GossipVersion(dbPol2.Version)) {
×
5035

×
5036
                return nil, nil, fmt.Errorf("unsupported policy2 version: %d",
×
5037
                        dbPol2.Version)
×
5038
        }
×
5039

5040
        var policyIDs = make([]int64, 0, 2)
×
5041
        if dbPol1 != nil {
×
5042
                policyIDs = append(policyIDs, dbPol1.ID)
×
5043
        }
×
5044
        if dbPol2 != nil {
×
5045
                policyIDs = append(policyIDs, dbPol2.ID)
×
5046
        }
×
5047

5048
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
5049
        if err != nil {
×
5050
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5051
                        "data: %w", err)
×
5052
        }
×
5053

5054
        pol1, err := buildChanPolicyWithBatchData(
×
5055
                true, dbPol1, channelID, node2, batchData,
×
5056
        )
×
5057
        if err != nil {
×
5058
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5059
        }
×
5060

5061
        pol2, err := buildChanPolicyWithBatchData(
×
5062
                false, dbPol2, channelID, node1, batchData,
×
5063
        )
×
5064
        if err != nil {
×
5065
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5066
        }
×
5067

5068
        return pol1, pol2, nil
×
5069
}
5070

5071
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
5072
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
5073
// then nil is returned for it.
5074
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5075
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
5076
        *models.CachedEdgePolicy, error) {
×
5077

×
5078
        var p1, p2 *models.CachedEdgePolicy
×
5079
        if dbPol1 != nil {
×
5080
                policy1, err := buildChanPolicy(
×
5081
                        true, *dbPol1, channelID, nil, node2,
×
5082
                )
×
5083
                if err != nil {
×
5084
                        return nil, nil, err
×
5085
                }
×
5086

5087
                p1 = models.NewCachedPolicy(policy1)
×
5088
        }
5089
        if dbPol2 != nil {
×
5090
                policy2, err := buildChanPolicy(
×
5091
                        false, *dbPol2, channelID, nil, node1,
×
5092
                )
×
5093
                if err != nil {
×
5094
                        return nil, nil, err
×
5095
                }
×
5096

5097
                p2 = models.NewCachedPolicy(policy2)
×
5098
        }
5099

5100
        return p1, p2, nil
×
5101
}
5102

5103
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
5104
// provided sqlc.GraphChannelPolicy and other required information.
5105
func buildChanPolicy(isNode1 bool, dbPolicy sqlc.GraphChannelPolicy,
5106
        channelID uint64, extras map[uint64][]byte,
5107
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
5108

×
5109
        var inboundFee fn.Option[lnwire.Fee]
×
5110
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
5111
                dbPolicy.InboundBaseFeeMsat.Valid {
×
5112

×
5113
                inboundFee = fn.Some(lnwire.Fee{
×
5114
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
5115
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
5116
                })
×
5117
        }
×
5118

5119
        p := &models.ChannelEdgePolicy{
×
5120
                Version:       lnwire.GossipVersion(dbPolicy.Version),
×
5121
                SigBytes:      dbPolicy.Signature,
×
5122
                ChannelID:     channelID,
×
5123
                SecondPeer:    !isNode1,
×
5124
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
5125
                MinHTLC: lnwire.MilliSatoshi(
×
5126
                        dbPolicy.MinHtlcMsat,
×
5127
                ),
×
5128
                MaxHTLC: lnwire.MilliSatoshi(
×
5129
                        dbPolicy.MaxHtlcMsat.Int64,
×
5130
                ),
×
5131
                FeeBaseMSat: lnwire.MilliSatoshi(
×
5132
                        dbPolicy.BaseFeeMsat,
×
5133
                ),
×
5134
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
5135
                ToNode:                    toNode,
×
5136
                InboundFee:                inboundFee,
×
5137
        }
×
5138

×
5139
        if p.Version != gossipV2 {
×
5140
                recs, err := lnwire.CustomRecords(extras).Serialize()
×
5141
                if err != nil {
×
5142
                        return nil, fmt.Errorf("unable to serialize extra "+
×
5143
                                "signed fields: %w", err)
×
5144
                }
×
5145

5146
                p.ExtraOpaqueData = recs
×
5147
                p.LastUpdate = time.Unix(dbPolicy.LastUpdate.Int64, 0)
×
5148
                //nolint:ll
×
5149
                p.MessageFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
5150
                        dbPolicy.MessageFlags,
×
5151
                )
×
5152
                //nolint:ll
×
5153
                p.ChannelFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
5154
                        dbPolicy.ChannelFlags,
×
5155
                )
×
5156
        } else {
×
5157
                if dbPolicy.BlockHeight.Valid {
×
5158
                        p.LastBlockHeight = uint32(
×
5159
                                dbPolicy.BlockHeight.Int64,
×
5160
                        )
×
5161
                }
×
5162

5163
                //nolint:ll
5164
                p.DisableFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateDisableFlags](
×
5165
                        dbPolicy.DisableFlags,
×
5166
                )
×
5167
                p.ExtraSignedFields = extras
×
5168
        }
5169

5170
        return p, nil
×
5171
}
5172

5173
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row, which is expected to be one of the sqlc row types that carry
// channel policy information (the supported types are the cases of the type
// switch below). It returns two policies, either of which may be nil if the
// corresponding policy information is not present in the row. An error is
// returned if the row is not one of the supported types.
//
// For every supported row type except
// ListChannelsWithPoliciesForCachePaginatedRow, a policy is only built when
// its nullable ID column is valid, and the policy's ChannelID is taken from
// the row's GraphChannel.ID rather than from a policy column.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
	*sqlc.GraphChannelPolicy, error) {

	var policy1, policy2 *sqlc.GraphChannelPolicy
	switch r := row.(type) {
	case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
		// For this row type the presence of a policy is detected via a
		// valid Timelock rather than a policy ID (presumably because this
		// cache-oriented query does not select the policy ID — confirm
		// against the query). The fields not mapped here (ID, ChannelID,
		// NodeID, LastUpdate, Signature) are left at their zero values.
		if r.Policy1Timelock.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				Version:                 r.Policy1Version.Int16,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2Timelock.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				Version:                 r.Policy2Version.Int16,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelByOutpointWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsForNodeIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsByNodeIDRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
				BlockHeight:             r.Policy1BlockHeight,
				DisableFlags:            r.Policy1DisableFlags,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
				BlockHeight:             r.Policy2BlockHeight,
				DisableFlags:            r.Policy2DisableFlags,
			}
		}

		return policy1, policy2, nil

	default:
		// Any other type is a programming error at the call site; %T
		// reports the dynamic type that was passed in.
		return nil, nil, fmt.Errorf("unexpected row type in "+
			"extractChannelPolicies: %T", r)
	}
}
5611

5612
// channelIDToBytes converts a channel ID (SCID) to a byte array
5613
// representation.
5614
func channelIDToBytes(channelID uint64) []byte {
×
5615
        var chanIDB [8]byte
×
5616
        byteOrder.PutUint64(chanIDB[:], channelID)
×
5617

×
5618
        return chanIDB[:]
×
5619
}
×
5620

5621
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
5622
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
5623
        if len(addresses) == 0 {
×
5624
                return nil, nil
×
5625
        }
×
5626

5627
        result := make([]net.Addr, 0, len(addresses))
×
5628
        for _, addr := range addresses {
×
5629
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
5630
                if err != nil {
×
5631
                        return nil, fmt.Errorf("unable to parse address %s "+
×
5632
                                "of type %d: %w", addr.address, addr.addrType,
×
5633
                                err)
×
5634
                }
×
5635
                if netAddr != nil {
×
5636
                        result = append(result, netAddr)
×
5637
                }
×
5638
        }
5639

5640
        // If we have no valid addresses, return nil instead of empty slice.
5641
        if len(result) == 0 {
×
5642
                return nil, nil
×
5643
        }
×
5644

5645
        return result, nil
×
5646
}
5647

5648
// parseAddress parses the given address string based on the address type
5649
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
5650
// and opaque addresses.
5651
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
5652
        switch addrType {
×
5653
        case addressTypeIPv4:
×
5654
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
5655
                if err != nil {
×
5656
                        return nil, err
×
5657
                }
×
5658

5659
                tcp.IP = tcp.IP.To4()
×
5660

×
5661
                return tcp, nil
×
5662

5663
        case addressTypeIPv6:
×
5664
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5665
                if err != nil {
×
5666
                        return nil, err
×
5667
                }
×
5668

5669
                return tcp, nil
×
5670

5671
        case addressTypeTorV3, addressTypeTorV2:
×
5672
                service, portStr, err := net.SplitHostPort(address)
×
5673
                if err != nil {
×
5674
                        return nil, fmt.Errorf("unable to split tor "+
×
5675
                                "address: %v", address)
×
5676
                }
×
5677

5678
                port, err := strconv.Atoi(portStr)
×
5679
                if err != nil {
×
5680
                        return nil, err
×
5681
                }
×
5682

5683
                return &tor.OnionAddr{
×
5684
                        OnionService: service,
×
5685
                        Port:         port,
×
5686
                }, nil
×
5687

5688
        case addressTypeDNS:
×
5689
                hostname, portStr, err := net.SplitHostPort(address)
×
5690
                if err != nil {
×
5691
                        return nil, fmt.Errorf("unable to split DNS "+
×
5692
                                "address: %v", address)
×
5693
                }
×
5694

5695
                port, err := strconv.Atoi(portStr)
×
5696
                if err != nil {
×
5697
                        return nil, err
×
5698
                }
×
5699

5700
                return &lnwire.DNSAddress{
×
5701
                        Hostname: hostname,
×
5702
                        Port:     uint16(port),
×
5703
                }, nil
×
5704

5705
        case addressTypeOpaque:
×
5706
                opaque, err := hex.DecodeString(address)
×
5707
                if err != nil {
×
5708
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5709
                                "address: %v", address)
×
5710
                }
×
5711

5712
                return &lnwire.OpaqueAddrs{
×
5713
                        Payload: opaque,
×
5714
                }, nil
×
5715

5716
        default:
×
5717
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5718
        }
5719
}
5720

5721
// batchNodeData holds all the related data for a batch of nodes.
5722
type batchNodeData struct {
5723
        // features is a map from a DB node ID to the feature bits for that
5724
        // node.
5725
        features map[int64][]int
5726

5727
        // addresses is a map from a DB node ID to the node's addresses.
5728
        addresses map[int64][]nodeAddress
5729

5730
        // extraFields is a map from a DB node ID to the extra signed fields
5731
        // for that node.
5732
        extraFields map[int64]map[uint64][]byte
5733
}
5734

5735
// nodeAddress holds the address type, position and address string for a
5736
// node. This is used to batch the fetching of node addresses.
5737
type nodeAddress struct {
5738
        addrType dbAddressType
5739
        position int32
5740
        address  string
5741
}
5742

5743
// batchLoadNodeData loads all related data for a batch of node IDs using the
5744
// provided SQLQueries interface. It returns a batchNodeData instance containing
5745
// the node features, addresses and extra signed fields.
5746
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5747
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5748

×
5749
        // Batch load the node features.
×
5750
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5751
        if err != nil {
×
5752
                return nil, fmt.Errorf("unable to batch load node "+
×
5753
                        "features: %w", err)
×
5754
        }
×
5755

5756
        // Batch load the node addresses.
5757
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5758
        if err != nil {
×
5759
                return nil, fmt.Errorf("unable to batch load node "+
×
5760
                        "addresses: %w", err)
×
5761
        }
×
5762

5763
        // Batch load the node extra signed fields.
5764
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5765
        if err != nil {
×
5766
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5767
                        "signed fields: %w", err)
×
5768
        }
×
5769

5770
        return &batchNodeData{
×
5771
                features:    features,
×
5772
                addresses:   addrs,
×
5773
                extraFields: extraTypes,
×
5774
        }, nil
×
5775
}
5776

5777
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5778
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5779
func batchLoadNodeFeaturesHelper(ctx context.Context,
5780
        cfg *sqldb.QueryConfig, db SQLQueries,
5781
        nodeIDs []int64) (map[int64][]int, error) {
×
5782

×
5783
        features := make(map[int64][]int)
×
5784

×
5785
        return features, sqldb.ExecuteBatchQuery(
×
5786
                ctx, cfg, nodeIDs,
×
5787
                func(id int64) int64 {
×
5788
                        return id
×
5789
                },
×
5790
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5791
                        error) {
×
5792

×
5793
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5794
                },
×
5795
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5796
                        features[feature.NodeID] = append(
×
5797
                                features[feature.NodeID],
×
5798
                                int(feature.FeatureBit),
×
5799
                        )
×
5800

×
5801
                        return nil
×
5802
                },
×
5803
        )
5804
}
5805

5806
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5807
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5808
// node ID to a slice of nodeAddress structs.
5809
func batchLoadNodeAddressesHelper(ctx context.Context,
5810
        cfg *sqldb.QueryConfig, db SQLQueries,
5811
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5812

×
5813
        addrs := make(map[int64][]nodeAddress)
×
5814

×
5815
        return addrs, sqldb.ExecuteBatchQuery(
×
5816
                ctx, cfg, nodeIDs,
×
5817
                func(id int64) int64 {
×
5818
                        return id
×
5819
                },
×
5820
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5821
                        error) {
×
5822

×
5823
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5824
                },
×
5825
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5826
                        addrs[addr.NodeID] = append(
×
5827
                                addrs[addr.NodeID], nodeAddress{
×
5828
                                        addrType: dbAddressType(addr.Type),
×
5829
                                        position: addr.Position,
×
5830
                                        address:  addr.Address,
×
5831
                                },
×
5832
                        )
×
5833

×
5834
                        return nil
×
5835
                },
×
5836
        )
5837
}
5838

5839
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5840
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5841
// query.
5842
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5843
        cfg *sqldb.QueryConfig, db SQLQueries,
5844
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5845

×
5846
        extraFields := make(map[int64]map[uint64][]byte)
×
5847

×
5848
        callback := func(ctx context.Context,
×
5849
                field sqlc.GraphNodeExtraType) error {
×
5850

×
5851
                if extraFields[field.NodeID] == nil {
×
5852
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5853
                }
×
5854
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5855

×
5856
                return nil
×
5857
        }
5858

5859
        return extraFields, sqldb.ExecuteBatchQuery(
×
5860
                ctx, cfg, nodeIDs,
×
5861
                func(id int64) int64 {
×
5862
                        return id
×
5863
                },
×
5864
                func(ctx context.Context, ids []int64) (
5865
                        []sqlc.GraphNodeExtraType, error) {
×
5866

×
5867
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5868
                },
×
5869
                callback,
5870
        )
5871
}
5872

5873
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5874
// from the provided sqlc.GraphChannelPolicy records and the
5875
// provided batchChannelData.
5876
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5877
        channelID uint64, node1, node2 route.Vertex,
5878
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5879
        *models.ChannelEdgePolicy, error) {
×
5880

×
5881
        pol1, err := buildChanPolicyWithBatchData(
×
5882
                true, dbPol1, channelID, node2, batchData,
×
5883
        )
×
5884
        if err != nil {
×
5885
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5886
        }
×
5887

5888
        pol2, err := buildChanPolicyWithBatchData(
×
5889
                false, dbPol2, channelID, node1, batchData,
×
5890
        )
×
5891
        if err != nil {
×
5892
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5893
        }
×
5894

5895
        return pol1, pol2, nil
×
5896
}
5897

5898
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5899
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5900
func buildChanPolicyWithBatchData(isNode1 bool,
5901
        dbPol *sqlc.GraphChannelPolicy, channelID uint64,
5902
        toNode route.Vertex, batchData *batchChannelData) (
5903
        *models.ChannelEdgePolicy, error) {
×
5904

×
5905
        if dbPol == nil {
×
5906
                return nil, nil
×
5907
        }
×
5908

5909
        var dbPol1Extras map[uint64][]byte
×
5910
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5911
                dbPol1Extras = extras
×
5912
        } else {
×
5913
                dbPol1Extras = make(map[uint64][]byte)
×
5914
        }
×
5915

5916
        return buildChanPolicy(isNode1, *dbPol, channelID, dbPol1Extras, toNode)
×
5917
}
5918

5919
// batchChannelData holds all the related data for a batch of channels.
5920
type batchChannelData struct {
5921
        // chanFeatures is a map from DB channel ID to a slice of feature bits.
5922
        chanfeatures map[int64][]int
5923

5924
        // chanExtras is a map from DB channel ID to a map of TLV type to
5925
        // extra signed field bytes.
5926
        chanExtraTypes map[int64]map[uint64][]byte
5927

5928
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5929
        // to extra signed field bytes.
5930
        policyExtras map[int64]map[uint64][]byte
5931
}
5932

5933
// batchLoadChannelData loads all related data for batches of channels and
5934
// policies.
5935
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5936
        db SQLQueries, channelIDs []int64,
5937
        policyIDs []int64) (*batchChannelData, error) {
×
5938

×
5939
        batchData := &batchChannelData{
×
5940
                chanfeatures:   make(map[int64][]int),
×
5941
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5942
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5943
        }
×
5944

×
5945
        // Batch load channel features and extras
×
5946
        var err error
×
5947
        if len(channelIDs) > 0 {
×
5948
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5949
                        ctx, cfg, db, channelIDs,
×
5950
                )
×
5951
                if err != nil {
×
5952
                        return nil, fmt.Errorf("unable to batch load "+
×
5953
                                "channel features: %w", err)
×
5954
                }
×
5955

5956
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5957
                        ctx, cfg, db, channelIDs,
×
5958
                )
×
5959
                if err != nil {
×
5960
                        return nil, fmt.Errorf("unable to batch load "+
×
5961
                                "channel extras: %w", err)
×
5962
                }
×
5963
        }
5964

5965
        if len(policyIDs) > 0 {
×
5966
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5967
                        ctx, cfg, db, policyIDs,
×
5968
                )
×
5969
                if err != nil {
×
5970
                        return nil, fmt.Errorf("unable to batch load "+
×
5971
                                "policy extras: %w", err)
×
5972
                }
×
5973
                batchData.policyExtras = policyExtras
×
5974
        }
5975

5976
        return batchData, nil
×
5977
}
5978

5979
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5980
// channel IDs using ExecuteBatchQuery wrapper around the
5981
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5982
// slice of feature bits.
5983
func batchLoadChannelFeaturesHelper(ctx context.Context,
5984
        cfg *sqldb.QueryConfig, db SQLQueries,
5985
        channelIDs []int64) (map[int64][]int, error) {
×
5986

×
5987
        features := make(map[int64][]int)
×
5988

×
5989
        return features, sqldb.ExecuteBatchQuery(
×
5990
                ctx, cfg, channelIDs,
×
5991
                func(id int64) int64 {
×
5992
                        return id
×
5993
                },
×
5994
                func(ctx context.Context,
5995
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5996

×
5997
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5998
                },
×
5999
                func(ctx context.Context,
6000
                        feature sqlc.GraphChannelFeature) error {
×
6001

×
6002
                        features[feature.ChannelID] = append(
×
6003
                                features[feature.ChannelID],
×
6004
                                int(feature.FeatureBit),
×
6005
                        )
×
6006

×
6007
                        return nil
×
6008
                },
×
6009
        )
6010
}
6011

6012
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
6013
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
6014
// query. It returns a map from DB channel ID to a map of TLV type to extra
6015
// signed field bytes.
6016
func batchLoadChannelExtrasHelper(ctx context.Context,
6017
        cfg *sqldb.QueryConfig, db SQLQueries,
6018
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
6019

×
6020
        extras := make(map[int64]map[uint64][]byte)
×
6021

×
6022
        cb := func(ctx context.Context,
×
6023
                extra sqlc.GraphChannelExtraType) error {
×
6024

×
6025
                if extras[extra.ChannelID] == nil {
×
6026
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
6027
                }
×
6028
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
6029

×
6030
                return nil
×
6031
        }
6032

6033
        return extras, sqldb.ExecuteBatchQuery(
×
6034
                ctx, cfg, channelIDs,
×
6035
                func(id int64) int64 {
×
6036
                        return id
×
6037
                },
×
6038
                func(ctx context.Context,
6039
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
6040

×
6041
                        return db.GetChannelExtrasBatch(ctx, ids)
×
6042
                }, cb,
×
6043
        )
6044
}
6045

6046
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
6047
// batch of policy IDs using ExecuteBatchQuery wrapper around the
6048
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
6049
// a map of TLV type to extra signed field bytes.
6050
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
6051
        cfg *sqldb.QueryConfig, db SQLQueries,
6052
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
6053

×
6054
        extras := make(map[int64]map[uint64][]byte)
×
6055

×
6056
        return extras, sqldb.ExecuteBatchQuery(
×
6057
                ctx, cfg, policyIDs,
×
6058
                func(id int64) int64 {
×
6059
                        return id
×
6060
                },
×
6061
                func(ctx context.Context, ids []int64) (
6062
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
6063

×
6064
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
6065
                },
×
6066
                func(ctx context.Context,
6067
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
6068

×
6069
                        if extras[row.PolicyID] == nil {
×
6070
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
6071
                        }
×
6072
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
6073

×
6074
                        return nil
×
6075
                },
6076
        )
6077
}
6078

6079
// forEachNodePaginated executes a paginated query to process each node in the
6080
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
6081
// and applies the provided processNode function to each node.
6082
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
6083
        db SQLQueries, protocol lnwire.GossipVersion,
6084
        processNode func(context.Context, int64,
6085
                *models.Node) error) error {
×
6086

×
6087
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
6088
                limit int32) ([]sqlc.GraphNode, error) {
×
6089

×
6090
                return db.ListNodesPaginated(
×
6091
                        ctx, sqlc.ListNodesPaginatedParams{
×
6092
                                Version: int16(protocol),
×
6093
                                ID:      lastID,
×
6094
                                Limit:   limit,
×
6095
                        },
×
6096
                )
×
6097
        }
×
6098

6099
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
6100
                return node.ID
×
6101
        }
×
6102

6103
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
6104
                return node.ID, nil
×
6105
        }
×
6106

6107
        batchQueryFunc := func(ctx context.Context,
×
6108
                nodeIDs []int64) (*batchNodeData, error) {
×
6109

×
6110
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
6111
        }
×
6112

6113
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
6114
                batchData *batchNodeData) error {
×
6115

×
6116
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
6117
                if err != nil {
×
6118
                        return fmt.Errorf("unable to build "+
×
6119
                                "node(id=%d): %w", dbNode.ID, err)
×
6120
                }
×
6121

6122
                return processNode(ctx, dbNode.ID, node)
×
6123
        }
6124

6125
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
6126
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
6127
                collectFunc, batchQueryFunc, processItem,
×
6128
        )
×
6129
}
6130

6131
// forEachChannelWithPolicies executes a paginated query to process each channel
6132
// with policies in the graph.
6133
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
6134
        cfg *SQLStoreConfig, v lnwire.GossipVersion,
6135
        processChannel func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
6136
                *models.ChannelEdgePolicy) error) error {
×
6137

×
6138
        type channelBatchIDs struct {
×
6139
                channelID int64
×
6140
                policyIDs []int64
×
6141
        }
×
6142

×
6143
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
6144
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
6145
                error) {
×
6146

×
6147
                return db.ListChannelsWithPoliciesPaginated(
×
6148
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
6149
                                Version: int16(v),
×
6150
                                ID:      lastID,
×
6151
                                Limit:   limit,
×
6152
                        },
×
6153
                )
×
6154
        }
×
6155

6156
        extractPageCursor := func(
×
6157
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
6158

×
6159
                return row.GraphChannel.ID
×
6160
        }
×
6161

6162
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
6163
                channelBatchIDs, error) {
×
6164

×
6165
                ids := channelBatchIDs{
×
6166
                        channelID: row.GraphChannel.ID,
×
6167
                }
×
6168

×
6169
                // Extract policy IDs from the row.
×
6170
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6171
                if err != nil {
×
6172
                        return ids, err
×
6173
                }
×
6174

6175
                if dbPol1 != nil {
×
6176
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
6177
                }
×
6178
                if dbPol2 != nil {
×
6179
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
6180
                }
×
6181

6182
                return ids, nil
×
6183
        }
6184

6185
        batchDataFunc := func(ctx context.Context,
×
6186
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
6187

×
6188
                // Separate channel IDs from policy IDs.
×
6189
                var (
×
6190
                        channelIDs = make([]int64, len(allIDs))
×
6191
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
6192
                )
×
6193

×
6194
                for i, ids := range allIDs {
×
6195
                        channelIDs[i] = ids.channelID
×
6196
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
6197
                }
×
6198

6199
                return batchLoadChannelData(
×
6200
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
6201
                )
×
6202
        }
6203

6204
        processItem := func(ctx context.Context,
×
6205
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
6206
                batchData *batchChannelData) error {
×
6207

×
6208
                node1, node2, err := buildNodeVertices(
×
6209
                        row.Node1Pubkey, row.Node2Pubkey,
×
6210
                )
×
6211
                if err != nil {
×
6212
                        return err
×
6213
                }
×
6214

6215
                edge, err := buildEdgeInfoWithBatchData(
×
6216
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
6217
                        batchData,
×
6218
                )
×
6219
                if err != nil {
×
6220
                        return fmt.Errorf("unable to build channel info: %w",
×
6221
                                err)
×
6222
                }
×
6223

6224
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6225
                if err != nil {
×
6226
                        return err
×
6227
                }
×
6228

6229
                p1, p2, err := buildChanPoliciesWithBatchData(
×
6230
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
6231
                )
×
6232
                if err != nil {
×
6233
                        return err
×
6234
                }
×
6235

6236
                return processChannel(edge, p1, p2)
×
6237
        }
6238

6239
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
6240
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
6241
                collectFunc, batchDataFunc, processItem,
×
6242
        )
×
6243
}
6244

6245
// buildDirectedChannel builds a DirectedChannel instance from the provided
6246
// data.
6247
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
6248
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
6249
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
6250
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
6251

×
6252
        node1, node2, err := buildNodeVertices(
×
6253
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
6254
        )
×
6255
        if err != nil {
×
6256
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
6257
        }
×
6258

6259
        edge, err := buildEdgeInfoWithBatchData(
×
6260
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
6261
        )
×
6262
        if err != nil {
×
6263
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
6264
        }
×
6265

6266
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
6267
        if err != nil {
×
6268
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
6269
                        err)
×
6270
        }
×
6271

6272
        p1, p2, err := buildChanPoliciesWithBatchData(
×
6273
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
6274
                channelBatchData,
×
6275
        )
×
6276
        if err != nil {
×
6277
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
6278
                        err)
×
6279
        }
×
6280

6281
        // Determine outgoing and incoming policy for this specific node.
6282
        p1ToNode := channelRow.GraphChannel.NodeID2
×
6283
        p2ToNode := channelRow.GraphChannel.NodeID1
×
6284
        outPolicy, inPolicy := p1, p2
×
6285
        if (p1 != nil && p1ToNode == nodeID) ||
×
6286
                (p2 != nil && p2ToNode != nodeID) {
×
6287

×
6288
                outPolicy, inPolicy = p2, p1
×
6289
        }
×
6290

6291
        // Build cached policy.
6292
        var cachedInPolicy *models.CachedEdgePolicy
×
6293
        if inPolicy != nil {
×
6294
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
6295
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
6296
                cachedInPolicy.ToNodeFeatures = features
×
6297
        }
×
6298

6299
        // Extract inbound fee.
6300
        var inboundFee lnwire.Fee
×
6301
        if outPolicy != nil {
×
6302
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
6303
                        inboundFee = fee
×
6304
                })
×
6305
        }
6306

6307
        // Build directed channel.
6308
        directedChannel := &DirectedChannel{
×
6309
                ChannelID:    edge.ChannelID,
×
6310
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
6311
                OtherNode:    edge.NodeKey2Bytes,
×
6312
                Capacity:     edge.Capacity,
×
6313
                OutPolicySet: outPolicy != nil,
×
6314
                InPolicy:     cachedInPolicy,
×
6315
                InboundFee:   inboundFee,
×
6316
        }
×
6317

×
6318
        if nodePub == edge.NodeKey2Bytes {
×
6319
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
6320
        }
×
6321

6322
        return directedChannel, nil
×
6323
}
6324

6325
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
6326
// provided rows. It uses batch loading for channels, policies, and nodes.
6327
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
6328
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
6329

×
6330
        var (
×
6331
                channelIDs = make([]int64, len(rows))
×
6332
                policyIDs  = make([]int64, 0, len(rows)*2)
×
6333
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
6334

×
6335
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
6336
                nodeIDSet = make(map[int64]bool)
×
6337

×
6338
                // edges will hold the final channel edges built from the rows.
×
6339
                edges = make([]ChannelEdge, 0, len(rows))
×
6340
        )
×
6341

×
6342
        // Collect all IDs needed for batch loading.
×
6343
        for i, row := range rows {
×
6344
                channelIDs[i] = row.Channel().ID
×
6345

×
6346
                // Collect policy IDs
×
6347
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6348
                if err != nil {
×
6349
                        return nil, fmt.Errorf("unable to extract channel "+
×
6350
                                "policies: %w", err)
×
6351
                }
×
6352
                if dbPol1 != nil {
×
6353
                        policyIDs = append(policyIDs, dbPol1.ID)
×
6354
                }
×
6355
                if dbPol2 != nil {
×
6356
                        policyIDs = append(policyIDs, dbPol2.ID)
×
6357
                }
×
6358

6359
                var (
×
6360
                        node1ID = row.Node1().ID
×
6361
                        node2ID = row.Node2().ID
×
6362
                )
×
6363

×
6364
                // Collect unique node IDs.
×
6365
                if !nodeIDSet[node1ID] {
×
6366
                        nodeIDs = append(nodeIDs, node1ID)
×
6367
                        nodeIDSet[node1ID] = true
×
6368
                }
×
6369

6370
                if !nodeIDSet[node2ID] {
×
6371
                        nodeIDs = append(nodeIDs, node2ID)
×
6372
                        nodeIDSet[node2ID] = true
×
6373
                }
×
6374
        }
6375

6376
        // Batch the data for all the channels and policies.
6377
        channelBatchData, err := batchLoadChannelData(
×
6378
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
6379
        )
×
6380
        if err != nil {
×
6381
                return nil, fmt.Errorf("unable to batch load channel and "+
×
6382
                        "policy data: %w", err)
×
6383
        }
×
6384

6385
        // Batch the data for all the nodes.
6386
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
6387
        if err != nil {
×
6388
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
6389
                        err)
×
6390
        }
×
6391

6392
        // Build all channel edges using batch data.
6393
        for _, row := range rows {
×
6394
                // Build nodes using batch data.
×
6395
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
6396
                if err != nil {
×
6397
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
6398
                }
×
6399

6400
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
6401
                if err != nil {
×
6402
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
6403
                }
×
6404

6405
                // Build channel info using batch data.
6406
                channel, err := buildEdgeInfoWithBatchData(
×
6407
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
6408
                        node2.PubKeyBytes, channelBatchData,
×
6409
                )
×
6410
                if err != nil {
×
6411
                        return nil, fmt.Errorf("unable to build channel "+
×
6412
                                "info: %w", err)
×
6413
                }
×
6414

6415
                // Extract and build policies using batch data.
6416
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6417
                if err != nil {
×
6418
                        return nil, fmt.Errorf("unable to extract channel "+
×
6419
                                "policies: %w", err)
×
6420
                }
×
6421

6422
                p1, p2, err := buildChanPoliciesWithBatchData(
×
6423
                        dbPol1, dbPol2, channel.ChannelID,
×
6424
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
6425
                )
×
6426
                if err != nil {
×
6427
                        return nil, fmt.Errorf("unable to build channel "+
×
6428
                                "policies: %w", err)
×
6429
                }
×
6430

6431
                edges = append(edges, ChannelEdge{
×
6432
                        Info:    channel,
×
6433
                        Policy1: p1,
×
6434
                        Policy2: p2,
×
6435
                        Node1:   node1,
×
6436
                        Node2:   node2,
×
6437
                })
×
6438
        }
6439

6440
        return edges, nil
×
6441
}
6442

6443
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
6444
// instances from the provided rows using batch loading for channel data.
6445
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
6446
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
6447
        []*models.ChannelEdgeInfo, []int64, error) {
×
6448

×
6449
        if len(rows) == 0 {
×
6450
                return nil, nil, nil
×
6451
        }
×
6452

6453
        // Collect all the channel IDs needed for batch loading.
6454
        channelIDs := make([]int64, len(rows))
×
6455
        for i, row := range rows {
×
6456
                channelIDs[i] = row.Channel().ID
×
6457
        }
×
6458

6459
        // Batch load the channel data.
6460
        channelBatchData, err := batchLoadChannelData(
×
6461
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
6462
        )
×
6463
        if err != nil {
×
6464
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
6465
                        "data: %w", err)
×
6466
        }
×
6467

6468
        // Build all channel edges using batch data.
6469
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
6470
        for _, row := range rows {
×
6471
                node1, node2, err := buildNodeVertices(
×
6472
                        row.Node1Pub(), row.Node2Pub(),
×
6473
                )
×
6474
                if err != nil {
×
6475
                        return nil, nil, err
×
6476
                }
×
6477

6478
                // Build channel info using batch data
6479
                info, err := buildEdgeInfoWithBatchData(
×
6480
                        cfg.ChainHash, row.Channel(), node1, node2,
×
6481
                        channelBatchData,
×
6482
                )
×
6483
                if err != nil {
×
6484
                        return nil, nil, err
×
6485
                }
×
6486

6487
                edges = append(edges, info)
×
6488
        }
6489

6490
        return edges, channelIDs, nil
×
6491
}
6492

6493
// handleZombieMarking is a helper function that handles the logic of
6494
// marking a channel as a zombie in the database. It takes into account whether
6495
// we are in strict zombie pruning mode, and adjusts the node public keys
6496
// accordingly based on the last update timestamps of the channel policies.
6497
func handleZombieMarking(ctx context.Context, db SQLQueries,
6498
        v lnwire.GossipVersion,
6499
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
6500
        strictZombiePruning bool, scid uint64) error {
×
6501

×
6502
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
6503

×
6504
        if strictZombiePruning {
×
6505
                // TODO(elle): update for V2 last update times.
×
6506
                if v != gossipV1 {
×
6507
                        return fmt.Errorf("strict zombie pruning only "+
×
6508
                                "supported for gossip v1, got %v", v)
×
6509
                }
×
6510

6511
                var e1UpdateTime, e2UpdateTime *time.Time
×
6512
                if row.Policy1LastUpdate.Valid {
×
6513
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
6514
                        e1UpdateTime = &e1Time
×
6515
                }
×
6516
                if row.Policy2LastUpdate.Valid {
×
6517
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
6518
                        e2UpdateTime = &e2Time
×
6519
                }
×
6520

6521
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
6522
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
6523
                        e2UpdateTime,
×
6524
                )
×
6525
        }
6526

6527
        return db.UpsertZombieChannel(
×
6528
                ctx, sqlc.UpsertZombieChannelParams{
×
6529
                        Version:  int16(v),
×
6530
                        Scid:     channelIDToBytes(scid),
×
6531
                        NodeKey1: nodeKey1[:],
×
6532
                        NodeKey2: nodeKey2[:],
×
6533
                },
×
6534
        )
×
6535
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc