• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16647329373

31 Jul 2025 11:08AM UTC coverage: 66.99% (-0.05%) from 67.044%
16647329373

Pull #10118

github

web-flow
Merge 369b2892f into b5c290d90
Pull Request #10118: [4] sqldb+graph/db: add and use new pagination & batch query helpers

6 of 410 new or added lines in 2 files covered. (1.46%)

128 existing lines in 27 files now uncovered.

135487 of 202249 relevant lines covered (66.99%)

21660.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// ProtocolVersion enumerates the gossip protocol versions that a message can
// belong to.
type ProtocolVersion uint8

// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
const ProtocolV1 ProtocolVersion = 1

// String renders the version as the letter "V" followed by its decimal
// value, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Converting to the underlying uint8 makes it explicit that the
	// numeric value (not this String method) is formatted.
	return fmt.Sprint("V", uint8(v))
}
×
47

48
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
49
// execute queries against the SQL graph tables.
50
//
51
//nolint:ll,interfacebloat
52
type SQLQueries interface {
53
        /*
54
                Node queries.
55
        */
56
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
57
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
58
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
59
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
60
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
61
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
62
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
63
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
64
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
65
        DeleteNode(ctx context.Context, id int64) error
66

67
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
68
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
69
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
70
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
71

72
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
73
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
74
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
75
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
76

77
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
78
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
79
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
80
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
81
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
82

83
        /*
84
                Source node queries.
85
        */
86
        AddSourceNode(ctx context.Context, nodeID int64) error
87
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
88

89
        /*
90
                Channel queries.
91
        */
92
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
93
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
94
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
95
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
96
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
100
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
105
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
106
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
107
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
108
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
109
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
110
        DeleteChannels(ctx context.Context, ids []int64) error
111

112
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
113
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
116

117
        /*
118
                Channel Policy table queries.
119
        */
120
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
121
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
122
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
123

124
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
125
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
126
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
127

128
        /*
129
                Zombie index queries.
130
        */
131
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
132
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
133
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
134
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
135
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
136

137
        /*
138
                Prune log table queries.
139
        */
140
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
141
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
142
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
143
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
144

145
        /*
146
                Closed SCID table queries.
147
        */
148
        InsertClosedChannel(ctx context.Context, scid []byte) error
149
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
150
}
151

152
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
153
// database operations.
154
type BatchedSQLQueries interface {
155
        SQLQueries
156
        sqldb.BatchedTx[SQLQueries]
157
}
158

159
// SQLStore is an implementation of the V1Store interface that uses a SQL
160
// database as the backend.
161
type SQLStore struct {
162
        cfg *SQLStoreConfig
163
        db  BatchedSQLQueries
164

165
        // cacheMu guards all caches (rejectCache and chanCache). If
166
        // this mutex will be acquired at the same time as the DB mutex then
167
        // the cacheMu MUST be acquired first to prevent deadlock.
168
        cacheMu     sync.RWMutex
169
        rejectCache *rejectCache
170
        chanCache   *channelCache
171

172
        chanScheduler batch.Scheduler[SQLQueries]
173
        nodeScheduler batch.Scheduler[SQLQueries]
174

175
        srcNodes  map[ProtocolVersion]*srcNodeInfo
176
        srcNodeMu sync.Mutex
177
}
178

179
// A compile-time assertion to ensure that SQLStore implements the V1Store
180
// interface.
181
var _ V1Store = (*SQLStore)(nil)
182

183
// SQLStoreConfig holds the configuration for the SQLStore.
184
type SQLStoreConfig struct {
185
        // ChainHash is the genesis hash for the chain that all the gossip
186
        // messages in this store are aimed at.
187
        ChainHash chainhash.Hash
188

189
        // QueryConfig holds configuration values for SQL queries.
190
        QueryCfg *sqldb.QueryConfig
191
}
192

193
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
194
// storage backend.
195
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
196
        options ...StoreOptionModifier) (*SQLStore, error) {
×
197

×
198
        opts := DefaultOptions()
×
199
        for _, o := range options {
×
200
                o(opts)
×
201
        }
×
202

203
        if opts.NoMigration {
×
204
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
205
                        "supported for SQL stores")
×
206
        }
×
207

208
        s := &SQLStore{
×
209
                cfg:         cfg,
×
210
                db:          db,
×
211
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
212
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
213
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
214
        }
×
215

×
216
        s.chanScheduler = batch.NewTimeScheduler(
×
217
                db, &s.cacheMu, opts.BatchCommitInterval,
×
218
        )
×
219
        s.nodeScheduler = batch.NewTimeScheduler(
×
220
                db, nil, opts.BatchCommitInterval,
×
221
        )
×
222

×
223
        return s, nil
×
224
}
225

226
// AddLightningNode adds a vertex/node to the graph database. If the node is not
227
// in the database from before, this will add a new, unconnected one to the
228
// graph. If it is present from before, this will update that node's
229
// information.
230
//
231
// NOTE: part of the V1Store interface.
232
func (s *SQLStore) AddLightningNode(ctx context.Context,
233
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
234

×
235
        r := &batch.Request[SQLQueries]{
×
236
                Opts: batch.NewSchedulerOptions(opts...),
×
237
                Do: func(queries SQLQueries) error {
×
238
                        _, err := upsertNode(ctx, queries, node)
×
239
                        return err
×
240
                },
×
241
        }
242

243
        return s.nodeScheduler.Execute(ctx, r)
×
244
}
245

246
// FetchLightningNode attempts to look up a target node by its identity public
247
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
248
// returned.
249
//
250
// NOTE: part of the V1Store interface.
251
func (s *SQLStore) FetchLightningNode(ctx context.Context,
252
        pubKey route.Vertex) (*models.LightningNode, error) {
×
253

×
254
        var node *models.LightningNode
×
255
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
256
                var err error
×
257
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
258

×
259
                return err
×
260
        }, sqldb.NoOpReset)
×
261
        if err != nil {
×
262
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
263
        }
×
264

265
        return node, nil
×
266
}
267

268
// HasLightningNode determines if the graph has a vertex identified by the
269
// target node identity public key. If the node exists in the database, a
270
// timestamp of when the data for the node was lasted updated is returned along
271
// with a true boolean. Otherwise, an empty time.Time is returned with a false
272
// boolean.
273
//
274
// NOTE: part of the V1Store interface.
275
func (s *SQLStore) HasLightningNode(ctx context.Context,
276
        pubKey [33]byte) (time.Time, bool, error) {
×
277

×
278
        var (
×
279
                exists     bool
×
280
                lastUpdate time.Time
×
281
        )
×
282
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
283
                dbNode, err := db.GetNodeByPubKey(
×
284
                        ctx, sqlc.GetNodeByPubKeyParams{
×
285
                                Version: int16(ProtocolV1),
×
286
                                PubKey:  pubKey[:],
×
287
                        },
×
288
                )
×
289
                if errors.Is(err, sql.ErrNoRows) {
×
290
                        return nil
×
291
                } else if err != nil {
×
292
                        return fmt.Errorf("unable to fetch node: %w", err)
×
293
                }
×
294

295
                exists = true
×
296

×
297
                if dbNode.LastUpdate.Valid {
×
298
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
299
                }
×
300

301
                return nil
×
302
        }, sqldb.NoOpReset)
303
        if err != nil {
×
304
                return time.Time{}, false,
×
305
                        fmt.Errorf("unable to fetch node: %w", err)
×
306
        }
×
307

308
        return lastUpdate, exists, nil
×
309
}
310

311
// AddrsForNode returns all known addresses for the target node public key
312
// that the graph DB is aware of. The returned boolean indicates if the
313
// given node is unknown to the graph DB or not.
314
//
315
// NOTE: part of the V1Store interface.
316
func (s *SQLStore) AddrsForNode(ctx context.Context,
317
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
318

×
319
        var (
×
320
                addresses []net.Addr
×
321
                known     bool
×
322
        )
×
323
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
324
                // First, check if the node exists and get its DB ID if it
×
325
                // does.
×
326
                dbID, err := db.GetNodeIDByPubKey(
×
327
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
328
                                Version: int16(ProtocolV1),
×
329
                                PubKey:  nodePub.SerializeCompressed(),
×
330
                        },
×
331
                )
×
332
                if errors.Is(err, sql.ErrNoRows) {
×
333
                        return nil
×
334
                }
×
335

336
                known = true
×
337

×
338
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
339
                if err != nil {
×
340
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
341
                                err)
×
342
                }
×
343

344
                return nil
×
345
        }, sqldb.NoOpReset)
346
        if err != nil {
×
347
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
348
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
349
        }
×
350

351
        return known, addresses, nil
×
352
}
353

354
// DeleteLightningNode starts a new database transaction to remove a vertex/node
355
// from the database according to the node's public key.
356
//
357
// NOTE: part of the V1Store interface.
358
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
359
        pubKey route.Vertex) error {
×
360

×
361
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
362
                res, err := db.DeleteNodeByPubKey(
×
363
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
364
                                Version: int16(ProtocolV1),
×
365
                                PubKey:  pubKey[:],
×
366
                        },
×
367
                )
×
368
                if err != nil {
×
369
                        return err
×
370
                }
×
371

372
                rows, err := res.RowsAffected()
×
373
                if err != nil {
×
374
                        return err
×
375
                }
×
376

377
                if rows == 0 {
×
378
                        return ErrGraphNodeNotFound
×
379
                } else if rows > 1 {
×
380
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
381
                }
×
382

383
                return err
×
384
        }, sqldb.NoOpReset)
385
        if err != nil {
×
386
                return fmt.Errorf("unable to delete node: %w", err)
×
387
        }
×
388

389
        return nil
×
390
}
391

392
// FetchNodeFeatures returns the features of the given node. If no features are
393
// known for the node, an empty feature vector is returned.
394
//
395
// NOTE: this is part of the graphdb.NodeTraverser interface.
396
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
397
        *lnwire.FeatureVector, error) {
×
398

×
399
        ctx := context.TODO()
×
400

×
401
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
402
}
×
403

404
// DisabledChannelIDs returns the channel ids of disabled channels.
405
// A channel is disabled when two of the associated ChanelEdgePolicies
406
// have their disabled bit on.
407
//
408
// NOTE: part of the V1Store interface.
409
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
410
        var (
×
411
                ctx     = context.TODO()
×
412
                chanIDs []uint64
×
413
        )
×
414
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
415
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
416
                if err != nil {
×
417
                        return fmt.Errorf("unable to fetch disabled "+
×
418
                                "channels: %w", err)
×
419
                }
×
420

421
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
422

×
423
                return nil
×
424
        }, sqldb.NoOpReset)
425
        if err != nil {
×
426
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
427
                        err)
×
428
        }
×
429

430
        return chanIDs, nil
×
431
}
432

433
// LookupAlias attempts to return the alias as advertised by the target node.
434
//
435
// NOTE: part of the V1Store interface.
436
func (s *SQLStore) LookupAlias(ctx context.Context,
437
        pub *btcec.PublicKey) (string, error) {
×
438

×
439
        var alias string
×
440
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
441
                dbNode, err := db.GetNodeByPubKey(
×
442
                        ctx, sqlc.GetNodeByPubKeyParams{
×
443
                                Version: int16(ProtocolV1),
×
444
                                PubKey:  pub.SerializeCompressed(),
×
445
                        },
×
446
                )
×
447
                if errors.Is(err, sql.ErrNoRows) {
×
448
                        return ErrNodeAliasNotFound
×
449
                } else if err != nil {
×
450
                        return fmt.Errorf("unable to fetch node: %w", err)
×
451
                }
×
452

453
                if !dbNode.Alias.Valid {
×
454
                        return ErrNodeAliasNotFound
×
455
                }
×
456

457
                alias = dbNode.Alias.String
×
458

×
459
                return nil
×
460
        }, sqldb.NoOpReset)
461
        if err != nil {
×
462
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
463
        }
×
464

465
        return alias, nil
×
466
}
467

468
// SourceNode returns the source node of the graph. The source node is treated
469
// as the center node within a star-graph. This method may be used to kick off
470
// a path finding algorithm in order to explore the reachability of another
471
// node based off the source node.
472
//
473
// NOTE: part of the V1Store interface.
474
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
475
        error) {
×
476

×
477
        var node *models.LightningNode
×
478
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
479
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
480
                if err != nil {
×
481
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
482
                                err)
×
483
                }
×
484

485
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
486

×
487
                return err
×
488
        }, sqldb.NoOpReset)
489
        if err != nil {
×
490
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
491
        }
×
492

493
        return node, nil
×
494
}
495

496
// SetSourceNode sets the source node within the graph database. The source
497
// node is to be used as the center of a star-graph within path finding
498
// algorithms.
499
//
500
// NOTE: part of the V1Store interface.
501
func (s *SQLStore) SetSourceNode(ctx context.Context,
502
        node *models.LightningNode) error {
×
503

×
504
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
505
                id, err := upsertNode(ctx, db, node)
×
506
                if err != nil {
×
507
                        return fmt.Errorf("unable to upsert source node: %w",
×
508
                                err)
×
509
                }
×
510

511
                // Make sure that if a source node for this version is already
512
                // set, then the ID is the same as the one we are about to set.
513
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
514
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
515
                        return fmt.Errorf("unable to fetch source node: %w",
×
516
                                err)
×
517
                } else if err == nil {
×
518
                        if dbSourceNodeID != id {
×
519
                                return fmt.Errorf("v1 source node already "+
×
520
                                        "set to a different node: %d vs %d",
×
521
                                        dbSourceNodeID, id)
×
522
                        }
×
523

524
                        return nil
×
525
                }
526

527
                return db.AddSourceNode(ctx, id)
×
528
        }, sqldb.NoOpReset)
529
}
530

531
// NodeUpdatesInHorizon returns all the known lightning node which have an
532
// update timestamp within the passed range. This method can be used by two
533
// nodes to quickly determine if they have the same set of up to date node
534
// announcements.
535
//
536
// NOTE: This is part of the V1Store interface.
537
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
538
        endTime time.Time) ([]models.LightningNode, error) {
×
539

×
540
        ctx := context.TODO()
×
541

×
542
        var nodes []models.LightningNode
×
543
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
544
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
545
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
546
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
547
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
548
                        },
×
549
                )
×
550
                if err != nil {
×
551
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
552
                }
×
553

554
                err = forEachNodeInBatch(
×
NEW
555
                        ctx, s.cfg.QueryCfg, db, dbNodes,
×
556
                        func(_ int64, node *models.LightningNode) error {
×
557
                                nodes = append(nodes, *node)
×
558

×
559
                                return nil
×
560
                        },
×
561
                )
562
                if err != nil {
×
563
                        return fmt.Errorf("unable to build nodes: %w", err)
×
564
                }
×
565

566
                return nil
×
567
        }, sqldb.NoOpReset)
568
        if err != nil {
×
569
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
570
        }
×
571

572
        return nodes, nil
×
573
}
574

575
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
576
// undirected edge from the two target nodes are created. The information stored
577
// denotes the static attributes of the channel, such as the channelID, the keys
578
// involved in creation of the channel, and the set of features that the channel
579
// supports. The chanPoint and chanID are used to uniquely identify the edge
580
// globally within the database.
581
//
582
// NOTE: part of the V1Store interface.
583
func (s *SQLStore) AddChannelEdge(ctx context.Context,
584
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
585

×
586
        var alreadyExists bool
×
587
        r := &batch.Request[SQLQueries]{
×
588
                Opts: batch.NewSchedulerOptions(opts...),
×
589
                Reset: func() {
×
590
                        alreadyExists = false
×
591
                },
×
592
                Do: func(tx SQLQueries) error {
×
593
                        _, err := insertChannel(ctx, tx, edge)
×
594

×
595
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
596
                        // succeed, but propagate the error via local state.
×
597
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
598
                                alreadyExists = true
×
599
                                return nil
×
600
                        }
×
601

602
                        return err
×
603
                },
604
                OnCommit: func(err error) error {
×
605
                        switch {
×
606
                        case err != nil:
×
607
                                return err
×
608
                        case alreadyExists:
×
609
                                return ErrEdgeAlreadyExist
×
610
                        default:
×
611
                                s.rejectCache.remove(edge.ChannelID)
×
612
                                s.chanCache.remove(edge.ChannelID)
×
613
                                return nil
×
614
                        }
615
                },
616
        }
617

618
        return s.chanScheduler.Execute(ctx, r)
×
619
}
620

621
// HighestChanID returns the "highest" known channel ID in the channel graph.
622
// This represents the "newest" channel from the PoV of the chain. This method
623
// can be used by peers to quickly determine if their graphs are in sync.
624
//
625
// NOTE: This is part of the V1Store interface.
626
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
627
        var highestChanID uint64
×
628
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
629
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
630
                if errors.Is(err, sql.ErrNoRows) {
×
631
                        return nil
×
632
                } else if err != nil {
×
633
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
634
                                err)
×
635
                }
×
636

637
                highestChanID = byteOrder.Uint64(chanID)
×
638

×
639
                return nil
×
640
        }, sqldb.NoOpReset)
641
        if err != nil {
×
642
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
643
        }
×
644

645
        return highestChanID, nil
×
646
}
647

648
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
649
// within the database for the referenced channel. The `flags` attribute within
650
// the ChannelEdgePolicy determines which of the directed edges are being
651
// updated. If the flag is 1, then the first node's information is being
652
// updated, otherwise it's the second node's information. The node ordering is
653
// determined by the lexicographical ordering of the identity public keys of the
654
// nodes on either side of the channel.
655
//
656
// NOTE: part of the V1Store interface.
657
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
658
        edge *models.ChannelEdgePolicy,
659
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
660

×
661
        var (
×
662
                isUpdate1    bool
×
663
                edgeNotFound bool
×
664
                from, to     route.Vertex
×
665
        )
×
666

×
667
        r := &batch.Request[SQLQueries]{
×
668
                Opts: batch.NewSchedulerOptions(opts...),
×
669
                Reset: func() {
×
670
                        isUpdate1 = false
×
671
                        edgeNotFound = false
×
672
                },
×
673
                Do: func(tx SQLQueries) error {
×
674
                        var err error
×
675
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
676
                                ctx, tx, edge,
×
677
                        )
×
678
                        if err != nil {
×
679
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
680
                        }
×
681

682
                        // Silence ErrEdgeNotFound so that the batch can
683
                        // succeed, but propagate the error via local state.
684
                        if errors.Is(err, ErrEdgeNotFound) {
×
685
                                edgeNotFound = true
×
686
                                return nil
×
687
                        }
×
688

689
                        return err
×
690
                },
691
                OnCommit: func(err error) error {
×
692
                        switch {
×
693
                        case err != nil:
×
694
                                return err
×
695
                        case edgeNotFound:
×
696
                                return ErrEdgeNotFound
×
697
                        default:
×
698
                                s.updateEdgeCache(edge, isUpdate1)
×
699
                                return nil
×
700
                        }
701
                },
702
        }
703

704
        err := s.chanScheduler.Execute(ctx, r)
×
705

×
706
        return from, to, err
×
707
}
708

709
// updateEdgeCache updates our reject and channel caches with the new
710
// edge policy information.
711
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
712
        isUpdate1 bool) {
×
713

×
714
        // If an entry for this channel is found in reject cache, we'll modify
×
715
        // the entry with the updated timestamp for the direction that was just
×
716
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
717
        // during the next query for this edge.
×
718
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
719
                if isUpdate1 {
×
720
                        entry.upd1Time = e.LastUpdate.Unix()
×
721
                } else {
×
722
                        entry.upd2Time = e.LastUpdate.Unix()
×
723
                }
×
724
                s.rejectCache.insert(e.ChannelID, entry)
×
725
        }
726

727
        // If an entry for this channel is found in channel cache, we'll modify
728
        // the entry with the updated policy for the direction that was just
729
        // written. If the edge doesn't exist, we'll defer loading the info and
730
        // policies and lazily read from disk during the next query.
731
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
732
                if isUpdate1 {
×
733
                        channel.Policy1 = e
×
734
                } else {
×
735
                        channel.Policy2 = e
×
736
                }
×
737
                s.chanCache.insert(e.ChannelID, channel)
×
738
        }
739
}
740

741
// ForEachSourceNodeChannel iterates through all channels of the source node,
742
// executing the passed callback on each. The call-back is provided with the
743
// channel's outpoint, whether we have a policy for the channel and the channel
744
// peer's node information.
745
//
746
// NOTE: part of the V1Store interface.
747
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
748
        cb func(chanPoint wire.OutPoint, havePolicy bool,
749
                otherNode *models.LightningNode) error, reset func()) error {
×
750

×
751
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
752
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
753
                if err != nil {
×
754
                        return fmt.Errorf("unable to fetch source node: %w",
×
755
                                err)
×
756
                }
×
757

758
                return forEachNodeChannel(
×
759
                        ctx, db, s.cfg.ChainHash, nodeID,
×
760
                        func(info *models.ChannelEdgeInfo,
×
761
                                outPolicy *models.ChannelEdgePolicy,
×
762
                                _ *models.ChannelEdgePolicy) error {
×
763

×
764
                                // Fetch the other node.
×
765
                                var (
×
766
                                        otherNodePub [33]byte
×
767
                                        node1        = info.NodeKey1Bytes
×
768
                                        node2        = info.NodeKey2Bytes
×
769
                                )
×
770
                                switch {
×
771
                                case bytes.Equal(node1[:], nodePub[:]):
×
772
                                        otherNodePub = node2
×
773
                                case bytes.Equal(node2[:], nodePub[:]):
×
774
                                        otherNodePub = node1
×
775
                                default:
×
776
                                        return fmt.Errorf("node not " +
×
777
                                                "participating in this channel")
×
778
                                }
779

780
                                _, otherNode, err := getNodeByPubKey(
×
781
                                        ctx, db, otherNodePub,
×
782
                                )
×
783
                                if err != nil {
×
784
                                        return fmt.Errorf("unable to fetch "+
×
785
                                                "other node(%x): %w",
×
786
                                                otherNodePub, err)
×
787
                                }
×
788

789
                                return cb(
×
790
                                        info.ChannelPoint, outPolicy != nil,
×
791
                                        otherNode,
×
792
                                )
×
793
                        },
794
                )
795
        }, reset)
796
}
797

798
// ForEachNode iterates through all the stored vertices/nodes in the graph,
799
// executing the passed callback with each node encountered. If the callback
800
// returns an error, then the transaction is aborted and the iteration stops
801
// early. Any operations performed on the NodeTx passed to the call-back are
802
// executed under the same read transaction and so, methods on the NodeTx object
803
// _MUST_ only be called from within the call-back.
804
//
805
// NOTE: part of the V1Store interface.
806
func (s *SQLStore) ForEachNode(ctx context.Context,
807
        cb func(tx NodeRTx) error, reset func()) error {
×
808

×
809
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
810
                return forEachNodePaginated(
×
NEW
811
                        ctx, s.cfg.QueryCfg, db,
×
NEW
812
                        func(ctx context.Context, dbNodeID int64,
×
NEW
813
                                node *models.LightningNode) error {
×
NEW
814

×
NEW
815
                                return cb(newSQLGraphNodeTx(
×
NEW
816
                                        db, s.cfg.ChainHash, dbNodeID, node,
×
NEW
817
                                ))
×
NEW
818
                        },
×
819
                )
820
        }, reset)
821
}
822

823
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction. It pairs an already-fetched node with the
// query handle it was fetched through, so follow-up reads (ForEachChannel,
// FetchNode) run through that same handle.
type sqlGraphNodeTx struct {
	// db is the query handle used for every read made through this
	// NodeRTx.
	db SQLQueries

	// id is the node's database ID. ForEachChannel passes it to
	// forEachNodeChannel to list the node's channels.
	id int64

	// node is the node record returned by Node.
	node *models.LightningNode

	// chain is the chain hash forwarded to forEachNodeChannel.
	// NOTE(review): presumably used when building the channel infos;
	// confirm in forEachNodeChannel.
	chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
835

836
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
837
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
838

×
839
        return &sqlGraphNodeTx{
×
840
                db:    db,
×
841
                chain: chain,
×
842
                id:    id,
×
843
                node:  node,
×
844
        }
×
845
}
×
846

847
// Node returns the raw information of the node.
848
//
849
// NOTE: This is a part of the NodeRTx interface.
850
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
851
        return s.node
×
852
}
×
853

854
// ForEachChannel can be used to iterate over the node's channels under the same
855
// transaction used to fetch the node.
856
//
857
// NOTE: This is a part of the NodeRTx interface.
858
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
859
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
860

×
861
        ctx := context.TODO()
×
862

×
863
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
864
}
×
865

866
// FetchNode fetches the node with the given pub key under the same transaction
867
// used to fetch the current node. The returned node is also a NodeRTx and any
868
// operations on that NodeRTx will also be done under the same transaction.
869
//
870
// NOTE: This is a part of the NodeRTx interface.
871
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
872
        ctx := context.TODO()
×
873

×
874
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
875
        if err != nil {
×
876
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
877
                        nodePub, err)
×
878
        }
×
879

880
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
881
}
882

883
// ForEachNodeDirectedChannel iterates through all channels of a given node,
884
// executing the passed callback on the directed edge representing the channel
885
// and its incoming policy. If the callback returns an error, then the iteration
886
// is halted with the error propagated back up to the caller.
887
//
888
// Unknown policies are passed into the callback as nil values.
889
//
890
// NOTE: this is part of the graphdb.NodeTraverser interface.
891
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
892
        cb func(channel *DirectedChannel) error, reset func()) error {
×
893

×
894
        var ctx = context.TODO()
×
895

×
896
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
897
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
898
        }, reset)
×
899
}
900

901
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
902
// graph, executing the passed callback with each node encountered. If the
903
// callback returns an error, then the transaction is aborted and the iteration
904
// stops early.
905
//
906
// NOTE: This is a part of the V1Store interface.
907
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
908
        cb func(route.Vertex, *lnwire.FeatureVector) error,
909
        reset func()) error {
×
910

×
911
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
912
                return forEachNodeCacheable(
×
NEW
913
                        ctx, s.cfg.QueryCfg, db,
×
NEW
914
                        func(nodeID int64, nodePub route.Vertex) error {
×
NEW
915
                                features, err := getNodeFeatures(
×
NEW
916
                                        ctx, db, nodeID,
×
NEW
917
                                )
×
NEW
918
                                if err != nil {
×
NEW
919
                                        return fmt.Errorf("unable to fetch "+
×
NEW
920
                                                "node features: %w", err)
×
NEW
921
                                }
×
922

NEW
923
                                return cb(nodePub, features)
×
924
                        },
925
                )
926
        }, reset)
927
        if err != nil {
×
928
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
929
        }
×
930

931
        return nil
×
932
}
933

934
// ForEachNodeChannel iterates through all channels of the given node,
935
// executing the passed callback with an edge info structure and the policies
936
// of each end of the channel. The first edge policy is the outgoing edge *to*
937
// the connecting node, while the second is the incoming edge *from* the
938
// connecting node. If the callback returns an error, then the iteration is
939
// halted with the error propagated back up to the caller.
940
//
941
// Unknown policies are passed into the callback as nil values.
942
//
943
// NOTE: part of the V1Store interface.
944
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
945
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
946
                *models.ChannelEdgePolicy) error, reset func()) error {
×
947

×
948
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
949
                dbNode, err := db.GetNodeByPubKey(
×
950
                        ctx, sqlc.GetNodeByPubKeyParams{
×
951
                                Version: int16(ProtocolV1),
×
952
                                PubKey:  nodePub[:],
×
953
                        },
×
954
                )
×
955
                if errors.Is(err, sql.ErrNoRows) {
×
956
                        return nil
×
957
                } else if err != nil {
×
958
                        return fmt.Errorf("unable to fetch node: %w", err)
×
959
                }
×
960

961
                return forEachNodeChannel(
×
962
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
963
                )
×
964
        }, reset)
965
}
966

967
// ChanUpdatesInHorizon returns all the known channel edges which have at least
968
// one edge that has an update timestamp within the specified horizon.
969
//
970
// NOTE: This is part of the V1Store interface.
971
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
972
        endTime time.Time) ([]ChannelEdge, error) {
×
973

×
974
        s.cacheMu.Lock()
×
975
        defer s.cacheMu.Unlock()
×
976

×
977
        var (
×
978
                ctx = context.TODO()
×
979
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
980
                // an additional map to keep track of the edges already seen to
×
981
                // prevent re-adding it.
×
982
                edgesSeen    = make(map[uint64]struct{})
×
983
                edgesToCache = make(map[uint64]ChannelEdge)
×
984
                edges        []ChannelEdge
×
985
                hits         int
×
986
        )
×
987
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
988
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
989
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
990
                                Version:   int16(ProtocolV1),
×
991
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
992
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
993
                        },
×
994
                )
×
995
                if err != nil {
×
996
                        return err
×
997
                }
×
998

999
                for _, row := range rows {
×
1000
                        // If we've already retrieved the info and policies for
×
1001
                        // this edge, then we can skip it as we don't need to do
×
1002
                        // so again.
×
1003
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1004
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1005
                                continue
×
1006
                        }
1007

1008
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1009
                                hits++
×
1010
                                edgesSeen[chanIDInt] = struct{}{}
×
1011
                                edges = append(edges, channel)
×
1012

×
1013
                                continue
×
1014
                        }
1015

1016
                        node1, node2, err := buildNodes(
×
1017
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1018
                        )
×
1019
                        if err != nil {
×
1020
                                return err
×
1021
                        }
×
1022

1023
                        channel, err := getAndBuildEdgeInfo(
×
1024
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1025
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1026
                        )
×
1027
                        if err != nil {
×
1028
                                return fmt.Errorf("unable to build channel "+
×
1029
                                        "info: %w", err)
×
1030
                        }
×
1031

1032
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1033
                        if err != nil {
×
1034
                                return fmt.Errorf("unable to extract channel "+
×
1035
                                        "policies: %w", err)
×
1036
                        }
×
1037

1038
                        p1, p2, err := getAndBuildChanPolicies(
×
1039
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1040
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1041
                        )
×
1042
                        if err != nil {
×
1043
                                return fmt.Errorf("unable to build channel "+
×
1044
                                        "policies: %w", err)
×
1045
                        }
×
1046

1047
                        edgesSeen[chanIDInt] = struct{}{}
×
1048
                        chanEdge := ChannelEdge{
×
1049
                                Info:    channel,
×
1050
                                Policy1: p1,
×
1051
                                Policy2: p2,
×
1052
                                Node1:   node1,
×
1053
                                Node2:   node2,
×
1054
                        }
×
1055
                        edges = append(edges, chanEdge)
×
1056
                        edgesToCache[chanIDInt] = chanEdge
×
1057
                }
1058

1059
                return nil
×
1060
        }, func() {
×
1061
                edgesSeen = make(map[uint64]struct{})
×
1062
                edgesToCache = make(map[uint64]ChannelEdge)
×
1063
                edges = nil
×
1064
        })
×
1065
        if err != nil {
×
1066
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1067
        }
×
1068

1069
        // Insert any edges loaded from disk into the cache.
1070
        for chanid, channel := range edgesToCache {
×
1071
                s.chanCache.insert(chanid, channel)
×
1072
        }
×
1073

1074
        if len(edges) > 0 {
×
1075
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1076
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1077
        } else {
×
1078
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1079
                        "horizon (%s, %s)", startTime, endTime)
×
1080
        }
×
1081

1082
        return edges, nil
×
1083
}
1084

1085
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1086
// data to the call-back.
1087
//
1088
// NOTE: The callback contents MUST not be modified.
1089
//
1090
// NOTE: part of the V1Store interface.
1091
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1092
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1093
        reset func()) error {
×
1094

×
NEW
1095
        handleNode := func(db SQLQueries, nodeID int64,
×
NEW
1096
                nodePub route.Vertex) error {
×
NEW
1097

×
NEW
1098
                features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
1099
                if err != nil {
×
NEW
1100
                        return fmt.Errorf("unable to fetch node(id=%d) "+
×
NEW
1101
                                "features: %w", nodeID, err)
×
NEW
1102
                }
×
1103

NEW
1104
                toNodeCallback := func() route.Vertex {
×
NEW
1105
                        return nodePub
×
NEW
1106
                }
×
1107

NEW
1108
                rows, err := db.ListChannelsByNodeID(
×
NEW
1109
                        ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
1110
                                Version: int16(ProtocolV1),
×
NEW
1111
                                NodeID1: nodeID,
×
NEW
1112
                        },
×
NEW
1113
                )
×
NEW
1114
                if err != nil {
×
NEW
1115
                        return fmt.Errorf("unable to fetch channels of "+
×
NEW
1116
                                "node(id=%d): %w", nodeID, err)
×
NEW
1117
                }
×
1118

NEW
1119
                channels := make(map[uint64]*DirectedChannel, len(rows))
×
NEW
1120
                for _, row := range rows {
×
NEW
1121
                        node1, node2, err := buildNodeVertices(
×
NEW
1122
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1123
                        )
×
1124
                        if err != nil {
×
NEW
1125
                                return err
×
1126
                        }
×
1127

NEW
1128
                        e, err := getAndBuildEdgeInfo(
×
NEW
1129
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
NEW
1130
                                node1, node2,
×
NEW
1131
                        )
×
NEW
1132
                        if err != nil {
×
NEW
1133
                                return fmt.Errorf("unable to build channel "+
×
NEW
1134
                                        "info: %w", err)
×
UNCOV
1135
                        }
×
1136

NEW
1137
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1138
                        if err != nil {
×
NEW
1139
                                return fmt.Errorf("unable to extract channel "+
×
NEW
1140
                                        "policies: %w", err)
×
UNCOV
1141
                        }
×
1142

NEW
1143
                        p1, p2, err := getAndBuildChanPolicies(
×
NEW
1144
                                ctx, db, dbPol1, dbPol2, e.ChannelID, node1,
×
NEW
1145
                                node2,
×
NEW
1146
                        )
×
NEW
1147
                        if err != nil {
×
NEW
1148
                                return fmt.Errorf("unable to build channel "+
×
NEW
1149
                                        "policies: %w", err)
×
NEW
1150
                        }
×
1151

1152
                        // Determine the outgoing and incoming policy
1153
                        // for this channel and node combo.
NEW
1154
                        outPolicy, inPolicy := p1, p2
×
NEW
1155
                        if p1 != nil && p1.ToNode == nodePub {
×
NEW
1156
                                outPolicy, inPolicy = p2, p1
×
NEW
1157
                        } else if p2 != nil && p2.ToNode != nodePub {
×
NEW
1158
                                outPolicy, inPolicy = p2, p1
×
NEW
1159
                        }
×
1160

NEW
1161
                        var cachedInPolicy *models.CachedEdgePolicy
×
NEW
1162
                        if inPolicy != nil {
×
NEW
1163
                                cachedInPolicy = models.NewCachedPolicy(
×
NEW
1164
                                        inPolicy,
×
1165
                                )
×
NEW
1166
                                cachedInPolicy.ToNodePubKey = toNodeCallback
×
NEW
1167
                                cachedInPolicy.ToNodeFeatures = features
×
NEW
1168
                        }
×
1169

NEW
1170
                        var inboundFee lnwire.Fee
×
NEW
1171
                        if outPolicy != nil {
×
NEW
1172
                                outPolicy.InboundFee.WhenSome(
×
NEW
1173
                                        func(fee lnwire.Fee) {
×
NEW
1174
                                                inboundFee = fee
×
NEW
1175
                                        },
×
1176
                                )
1177
                        }
1178

NEW
1179
                        directedChannel := &DirectedChannel{
×
NEW
1180
                                ChannelID:    e.ChannelID,
×
NEW
1181
                                IsNode1:      nodePub == e.NodeKey1Bytes,
×
NEW
1182
                                OtherNode:    e.NodeKey2Bytes,
×
NEW
1183
                                Capacity:     e.Capacity,
×
NEW
1184
                                OutPolicySet: outPolicy != nil,
×
NEW
1185
                                InPolicy:     cachedInPolicy,
×
NEW
1186
                                InboundFee:   inboundFee,
×
NEW
1187
                        }
×
1188

×
NEW
1189
                        if nodePub == e.NodeKey2Bytes {
×
NEW
1190
                                directedChannel.OtherNode = e.NodeKey1Bytes
×
NEW
1191
                        }
×
1192

NEW
1193
                        channels[e.ChannelID] = directedChannel
×
1194
                }
1195

NEW
1196
                return cb(nodePub, channels)
×
1197
        }
1198

NEW
1199
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1200
                return forEachNodeCacheable(
×
NEW
1201
                        ctx, s.cfg.QueryCfg, db,
×
NEW
1202
                        func(nodeID int64, nodePub route.Vertex) error {
×
NEW
1203
                                return handleNode(db, nodeID, nodePub)
×
NEW
1204
                        },
×
1205
                )
1206
        }, reset)
1207
}
1208

1209
// ForEachChannelCacheable iterates through all the channel edges stored
1210
// within the graph and invokes the passed callback for each edge. The
1211
// callback takes two edges as since this is a directed graph, both the
1212
// in/out edges are visited. If the callback returns an error, then the
1213
// transaction is aborted and the iteration stops early.
1214
//
1215
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1216
// pointer for that particular channel edge routing policy will be
1217
// passed into the callback.
1218
//
1219
// NOTE: this method is like ForEachChannel but fetches only the data
1220
// required for the graph cache.
1221
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1222
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1223
        reset func()) error {
×
1224

×
1225
        ctx := context.TODO()
×
1226

×
NEW
1227
        handleChannel := func(_ context.Context,
×
1228
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1229

×
1230
                node1, node2, err := buildNodeVertices(
×
1231
                        row.Node1Pubkey, row.Node2Pubkey,
×
1232
                )
×
1233
                if err != nil {
×
1234
                        return err
×
1235
                }
×
1236

1237
                edge := buildCacheableChannelInfo(
×
1238
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1239
                )
×
1240

×
1241
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1242
                if err != nil {
×
1243
                        return err
×
1244
                }
×
1245

NEW
1246
                pol1, pol2, err := buildCachedChanPolicies(
×
NEW
1247
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
1248
                )
×
NEW
1249
                if err != nil {
×
NEW
1250
                        return err
×
1251
                }
×
1252

NEW
1253
                return cb(edge, pol1, pol2)
×
1254
        }
1255

NEW
1256
        extractCursor := func(
×
NEW
1257
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1258

×
NEW
1259
                return row.ID
×
UNCOV
1260
        }
×
1261

1262
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1263
                //nolint:ll
×
NEW
1264
                queryFunc := func(ctx context.Context, lastID int64,
×
NEW
1265
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
NEW
1266
                        error) {
×
NEW
1267

×
NEW
1268
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1269
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1270
                                        Version: int16(ProtocolV1),
×
1271
                                        ID:      lastID,
×
NEW
1272
                                        Limit:   limit,
×
1273
                                },
×
1274
                        )
×
1275
                }
×
1276

NEW
1277
                return sqldb.ExecutePaginatedQuery(
×
NEW
1278
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
1279
                        extractCursor, handleChannel,
×
NEW
1280
                )
×
1281
        }, reset)
1282
}
1283

1284
// ForEachChannel iterates through all the channel edges stored within the
1285
// graph and invokes the passed callback for each edge. The callback takes two
1286
// edges as since this is a directed graph, both the in/out edges are visited.
1287
// If the callback returns an error, then the transaction is aborted and the
1288
// iteration stops early.
1289
//
1290
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1291
// for that particular channel edge routing policy will be passed into the
1292
// callback.
1293
//
1294
// NOTE: part of the V1Store interface.
1295
func (s *SQLStore) ForEachChannel(ctx context.Context,
1296
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1297
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1298

×
1299
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1300
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1301
        }, reset)
×
1302
}
1303

1304
// FilterChannelRange returns the channel ID's of all known channels which were
1305
// mined in a block height within the passed range. The channel IDs are grouped
1306
// by their common block height. This method can be used to quickly share with a
1307
// peer the set of channels we know of within a particular range to catch them
1308
// up after a period of time offline. If withTimestamps is true then the
1309
// timestamp info of the latest received channel update messages of the channel
1310
// will be included in the response.
1311
//
1312
// NOTE: This is part of the V1Store interface.
1313
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1314
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1315

×
1316
        var (
×
1317
                ctx       = context.TODO()
×
1318
                startSCID = &lnwire.ShortChannelID{
×
1319
                        BlockHeight: startHeight,
×
1320
                }
×
1321
                endSCID = lnwire.ShortChannelID{
×
1322
                        BlockHeight: endHeight,
×
1323
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1324
                        TxPosition:  math.MaxUint16,
×
1325
                }
×
1326
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1327
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1328
        )
×
1329

×
1330
        // 1) get all channels where channelID is between start and end chan ID.
×
1331
        // 2) skip if not public (ie, no channel_proof)
×
1332
        // 3) collect that channel.
×
1333
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1334
        //    and add those timestamps to the collected channel.
×
1335
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1336
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1337
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1338
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1339
                                StartScid: chanIDStart,
×
1340
                                EndScid:   chanIDEnd,
×
1341
                        },
×
1342
                )
×
1343
                if err != nil {
×
1344
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1345
                                err)
×
1346
                }
×
1347

1348
                for _, dbChan := range dbChans {
×
1349
                        cid := lnwire.NewShortChanIDFromInt(
×
1350
                                byteOrder.Uint64(dbChan.Scid),
×
1351
                        )
×
1352
                        chanInfo := NewChannelUpdateInfo(
×
1353
                                cid, time.Time{}, time.Time{},
×
1354
                        )
×
1355

×
1356
                        if !withTimestamps {
×
1357
                                channelsPerBlock[cid.BlockHeight] = append(
×
1358
                                        channelsPerBlock[cid.BlockHeight],
×
1359
                                        chanInfo,
×
1360
                                )
×
1361

×
1362
                                continue
×
1363
                        }
1364

1365
                        //nolint:ll
1366
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1367
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1368
                                        Version:   int16(ProtocolV1),
×
1369
                                        ChannelID: dbChan.ID,
×
1370
                                        NodeID:    dbChan.NodeID1,
×
1371
                                },
×
1372
                        )
×
1373
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1374
                                return fmt.Errorf("unable to fetch node1 "+
×
1375
                                        "policy: %w", err)
×
1376
                        } else if err == nil {
×
1377
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1378
                                        node1Policy.LastUpdate.Int64, 0,
×
1379
                                )
×
1380
                        }
×
1381

1382
                        //nolint:ll
1383
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1384
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1385
                                        Version:   int16(ProtocolV1),
×
1386
                                        ChannelID: dbChan.ID,
×
1387
                                        NodeID:    dbChan.NodeID2,
×
1388
                                },
×
1389
                        )
×
1390
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1391
                                return fmt.Errorf("unable to fetch node2 "+
×
1392
                                        "policy: %w", err)
×
1393
                        } else if err == nil {
×
1394
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1395
                                        node2Policy.LastUpdate.Int64, 0,
×
1396
                                )
×
1397
                        }
×
1398

1399
                        channelsPerBlock[cid.BlockHeight] = append(
×
1400
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1401
                        )
×
1402
                }
1403

1404
                return nil
×
1405
        }, func() {
×
1406
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1407
        })
×
1408
        if err != nil {
×
1409
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1410
        }
×
1411

1412
        if len(channelsPerBlock) == 0 {
×
1413
                return nil, nil
×
1414
        }
×
1415

1416
        // Return the channel ranges in ascending block height order.
1417
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1418
        slices.Sort(blocks)
×
1419

×
1420
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1421
                return BlockChannelRange{
×
1422
                        Height:   block,
×
1423
                        Channels: channelsPerBlock[block],
×
1424
                }
×
1425
        }), nil
×
1426
}
1427

1428
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1429
// zombie. This method is used on an ad-hoc basis, when channels need to be
1430
// marked as zombies outside the normal pruning cycle.
1431
//
1432
// NOTE: part of the V1Store interface.
1433
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1434
        pubKey1, pubKey2 [33]byte) error {
×
1435

×
1436
        ctx := context.TODO()
×
1437

×
1438
        s.cacheMu.Lock()
×
1439
        defer s.cacheMu.Unlock()
×
1440

×
1441
        chanIDB := channelIDToBytes(chanID)
×
1442

×
1443
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1444
                return db.UpsertZombieChannel(
×
1445
                        ctx, sqlc.UpsertZombieChannelParams{
×
1446
                                Version:  int16(ProtocolV1),
×
1447
                                Scid:     chanIDB,
×
1448
                                NodeKey1: pubKey1[:],
×
1449
                                NodeKey2: pubKey2[:],
×
1450
                        },
×
1451
                )
×
1452
        }, sqldb.NoOpReset)
×
1453
        if err != nil {
×
1454
                return fmt.Errorf("unable to upsert zombie channel "+
×
1455
                        "(channel_id=%d): %w", chanID, err)
×
1456
        }
×
1457

1458
        s.rejectCache.remove(chanID)
×
1459
        s.chanCache.remove(chanID)
×
1460

×
1461
        return nil
×
1462
}
1463

1464
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1465
//
1466
// NOTE: part of the V1Store interface.
1467
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1468
        s.cacheMu.Lock()
×
1469
        defer s.cacheMu.Unlock()
×
1470

×
1471
        var (
×
1472
                ctx     = context.TODO()
×
1473
                chanIDB = channelIDToBytes(chanID)
×
1474
        )
×
1475

×
1476
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1477
                res, err := db.DeleteZombieChannel(
×
1478
                        ctx, sqlc.DeleteZombieChannelParams{
×
1479
                                Scid:    chanIDB,
×
1480
                                Version: int16(ProtocolV1),
×
1481
                        },
×
1482
                )
×
1483
                if err != nil {
×
1484
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1485
                                err)
×
1486
                }
×
1487

1488
                rows, err := res.RowsAffected()
×
1489
                if err != nil {
×
1490
                        return err
×
1491
                }
×
1492

1493
                if rows == 0 {
×
1494
                        return ErrZombieEdgeNotFound
×
1495
                } else if rows > 1 {
×
1496
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1497
                                "expected 1", rows)
×
1498
                }
×
1499

1500
                return nil
×
1501
        }, sqldb.NoOpReset)
1502
        if err != nil {
×
1503
                return fmt.Errorf("unable to mark edge live "+
×
1504
                        "(channel_id=%d): %w", chanID, err)
×
1505
        }
×
1506

1507
        s.rejectCache.remove(chanID)
×
1508
        s.chanCache.remove(chanID)
×
1509

×
1510
        return err
×
1511
}
1512

1513
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1514
// zombie, then the two node public keys corresponding to this edge are also
1515
// returned.
1516
//
1517
// NOTE: part of the V1Store interface.
1518
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1519
        error) {
×
1520

×
1521
        var (
×
1522
                ctx              = context.TODO()
×
1523
                isZombie         bool
×
1524
                pubKey1, pubKey2 route.Vertex
×
1525
                chanIDB          = channelIDToBytes(chanID)
×
1526
        )
×
1527

×
1528
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1529
                zombie, err := db.GetZombieChannel(
×
1530
                        ctx, sqlc.GetZombieChannelParams{
×
1531
                                Scid:    chanIDB,
×
1532
                                Version: int16(ProtocolV1),
×
1533
                        },
×
1534
                )
×
1535
                if errors.Is(err, sql.ErrNoRows) {
×
1536
                        return nil
×
1537
                }
×
1538
                if err != nil {
×
1539
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1540
                                err)
×
1541
                }
×
1542

1543
                copy(pubKey1[:], zombie.NodeKey1)
×
1544
                copy(pubKey2[:], zombie.NodeKey2)
×
1545
                isZombie = true
×
1546

×
1547
                return nil
×
1548
        }, sqldb.NoOpReset)
1549
        if err != nil {
×
1550
                return false, route.Vertex{}, route.Vertex{},
×
1551
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1552
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1553
        }
×
1554

1555
        return isZombie, pubKey1, pubKey2, nil
×
1556
}
1557

1558
// NumZombies returns the current number of zombie channels in the graph.
1559
//
1560
// NOTE: part of the V1Store interface.
1561
func (s *SQLStore) NumZombies() (uint64, error) {
×
1562
        var (
×
1563
                ctx        = context.TODO()
×
1564
                numZombies uint64
×
1565
        )
×
1566
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1567
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1568
                if err != nil {
×
1569
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1570
                                err)
×
1571
                }
×
1572

1573
                numZombies = uint64(count)
×
1574

×
1575
                return nil
×
1576
        }, sqldb.NoOpReset)
1577
        if err != nil {
×
1578
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1579
        }
×
1580

1581
        return numZombies, nil
×
1582
}
1583

1584
// DeleteChannelEdges removes edges with the given channel IDs from the
1585
// database and marks them as zombies. This ensures that we're unable to re-add
1586
// it to our database once again. If an edge does not exist within the
1587
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1588
// true, then when we mark these edges as zombies, we'll set up the keys such
1589
// that we require the node that failed to send the fresh update to be the one
1590
// that resurrects the channel from its zombie state. The markZombie bool
1591
// denotes whether to mark the channel as a zombie.
1592
//
1593
// NOTE: part of the V1Store interface.
1594
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1595
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1596

×
1597
        s.cacheMu.Lock()
×
1598
        defer s.cacheMu.Unlock()
×
1599

×
1600
        // Keep track of which channels we end up finding so that we can
×
1601
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1602
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1603
        for _, chanID := range chanIDs {
×
1604
                chanLookup[chanID] = struct{}{}
×
1605
        }
×
1606

1607
        var (
×
1608
                ctx     = context.TODO()
×
1609
                deleted []*models.ChannelEdgeInfo
×
1610
        )
×
1611
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1612
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1613
                chanCallBack := func(ctx context.Context,
×
1614
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1615

×
1616
                        // Deleting the entry from the map indicates that we
×
1617
                        // have found the channel.
×
1618
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1619
                        delete(chanLookup, scid)
×
1620

×
1621
                        node1, node2, err := buildNodeVertices(
×
1622
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1623
                        )
×
1624
                        if err != nil {
×
1625
                                return err
×
1626
                        }
×
1627

1628
                        info, err := getAndBuildEdgeInfo(
×
1629
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1630
                                node1, node2,
×
1631
                        )
×
1632
                        if err != nil {
×
1633
                                return err
×
1634
                        }
×
1635

1636
                        deleted = append(deleted, info)
×
1637
                        chanIDsToDelete = append(
×
1638
                                chanIDsToDelete, row.GraphChannel.ID,
×
1639
                        )
×
1640

×
1641
                        if !markZombie {
×
1642
                                return nil
×
1643
                        }
×
1644

1645
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1646
                                info.NodeKey2Bytes
×
1647
                        if strictZombiePruning {
×
1648
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1649
                                if row.Policy1LastUpdate.Valid {
×
1650
                                        e1Time := time.Unix(
×
1651
                                                row.Policy1LastUpdate.Int64, 0,
×
1652
                                        )
×
1653
                                        e1UpdateTime = &e1Time
×
1654
                                }
×
1655
                                if row.Policy2LastUpdate.Valid {
×
1656
                                        e2Time := time.Unix(
×
1657
                                                row.Policy2LastUpdate.Int64, 0,
×
1658
                                        )
×
1659
                                        e2UpdateTime = &e2Time
×
1660
                                }
×
1661

1662
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1663
                                        info, e1UpdateTime, e2UpdateTime,
×
1664
                                )
×
1665
                        }
1666

1667
                        err = db.UpsertZombieChannel(
×
1668
                                ctx, sqlc.UpsertZombieChannelParams{
×
1669
                                        Version:  int16(ProtocolV1),
×
1670
                                        Scid:     channelIDToBytes(scid),
×
1671
                                        NodeKey1: nodeKey1[:],
×
1672
                                        NodeKey2: nodeKey2[:],
×
1673
                                },
×
1674
                        )
×
1675
                        if err != nil {
×
1676
                                return fmt.Errorf("unable to mark channel as "+
×
1677
                                        "zombie: %w", err)
×
1678
                        }
×
1679

1680
                        return nil
×
1681
                }
1682

1683
                err := s.forEachChanWithPoliciesInSCIDList(
×
1684
                        ctx, db, chanCallBack, chanIDs,
×
1685
                )
×
1686
                if err != nil {
×
1687
                        return err
×
1688
                }
×
1689

1690
                if len(chanLookup) > 0 {
×
1691
                        return ErrEdgeNotFound
×
1692
                }
×
1693

1694
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1695
        }, func() {
×
1696
                deleted = nil
×
1697

×
1698
                // Re-fill the lookup map.
×
1699
                for _, chanID := range chanIDs {
×
1700
                        chanLookup[chanID] = struct{}{}
×
1701
                }
×
1702
        })
1703
        if err != nil {
×
1704
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1705
                        err)
×
1706
        }
×
1707

1708
        for _, chanID := range chanIDs {
×
1709
                s.rejectCache.remove(chanID)
×
1710
                s.chanCache.remove(chanID)
×
1711
        }
×
1712

1713
        return deleted, nil
×
1714
}
1715

1716
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1717
// channel identified by the channel ID. If the channel can't be found, then
1718
// ErrEdgeNotFound is returned. A struct which houses the general information
1719
// for the channel itself is returned as well as two structs that contain the
1720
// routing policies for the channel in either direction.
1721
//
1722
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1723
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1724
// the ChannelEdgeInfo will only include the public keys of each node.
1725
//
1726
// NOTE: part of the V1Store interface.
1727
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1728
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1729
        *models.ChannelEdgePolicy, error) {
×
1730

×
1731
        var (
×
1732
                ctx              = context.TODO()
×
1733
                edge             *models.ChannelEdgeInfo
×
1734
                policy1, policy2 *models.ChannelEdgePolicy
×
1735
                chanIDB          = channelIDToBytes(chanID)
×
1736
        )
×
1737
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1738
                row, err := db.GetChannelBySCIDWithPolicies(
×
1739
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1740
                                Scid:    chanIDB,
×
1741
                                Version: int16(ProtocolV1),
×
1742
                        },
×
1743
                )
×
1744
                if errors.Is(err, sql.ErrNoRows) {
×
1745
                        // First check if this edge is perhaps in the zombie
×
1746
                        // index.
×
1747
                        zombie, err := db.GetZombieChannel(
×
1748
                                ctx, sqlc.GetZombieChannelParams{
×
1749
                                        Scid:    chanIDB,
×
1750
                                        Version: int16(ProtocolV1),
×
1751
                                },
×
1752
                        )
×
1753
                        if errors.Is(err, sql.ErrNoRows) {
×
1754
                                return ErrEdgeNotFound
×
1755
                        } else if err != nil {
×
1756
                                return fmt.Errorf("unable to check if "+
×
1757
                                        "channel is zombie: %w", err)
×
1758
                        }
×
1759

1760
                        // At this point, we know the channel is a zombie, so
1761
                        // we'll return an error indicating this, and we will
1762
                        // populate the edge info with the public keys of each
1763
                        // party as this is the only information we have about
1764
                        // it.
1765
                        edge = &models.ChannelEdgeInfo{}
×
1766
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1767
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1768

×
1769
                        return ErrZombieEdge
×
1770
                } else if err != nil {
×
1771
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1772
                }
×
1773

1774
                node1, node2, err := buildNodeVertices(
×
1775
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1776
                )
×
1777
                if err != nil {
×
1778
                        return err
×
1779
                }
×
1780

1781
                edge, err = getAndBuildEdgeInfo(
×
1782
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1783
                        node2,
×
1784
                )
×
1785
                if err != nil {
×
1786
                        return fmt.Errorf("unable to build channel info: %w",
×
1787
                                err)
×
1788
                }
×
1789

1790
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1791
                if err != nil {
×
1792
                        return fmt.Errorf("unable to extract channel "+
×
1793
                                "policies: %w", err)
×
1794
                }
×
1795

1796
                policy1, policy2, err = getAndBuildChanPolicies(
×
1797
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1798
                )
×
1799
                if err != nil {
×
1800
                        return fmt.Errorf("unable to build channel "+
×
1801
                                "policies: %w", err)
×
1802
                }
×
1803

1804
                return nil
×
1805
        }, sqldb.NoOpReset)
1806
        if err != nil {
×
1807
                // If we are returning the ErrZombieEdge, then we also need to
×
1808
                // return the edge info as the method comment indicates that
×
1809
                // this will be populated when the edge is a zombie.
×
1810
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1811
                        err)
×
1812
        }
×
1813

1814
        return edge, policy1, policy2, nil
×
1815
}
1816

1817
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1818
// the channel identified by the funding outpoint. If the channel can't be
1819
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1820
// information for the channel itself is returned as well as two structs that
1821
// contain the routing policies for the channel in either direction.
1822
//
1823
// NOTE: part of the V1Store interface.
1824
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1825
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1826
        *models.ChannelEdgePolicy, error) {
×
1827

×
1828
        var (
×
1829
                ctx              = context.TODO()
×
1830
                edge             *models.ChannelEdgeInfo
×
1831
                policy1, policy2 *models.ChannelEdgePolicy
×
1832
        )
×
1833
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1834
                row, err := db.GetChannelByOutpointWithPolicies(
×
1835
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1836
                                Outpoint: op.String(),
×
1837
                                Version:  int16(ProtocolV1),
×
1838
                        },
×
1839
                )
×
1840
                if errors.Is(err, sql.ErrNoRows) {
×
1841
                        return ErrEdgeNotFound
×
1842
                } else if err != nil {
×
1843
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1844
                }
×
1845

1846
                node1, node2, err := buildNodeVertices(
×
1847
                        row.Node1Pubkey, row.Node2Pubkey,
×
1848
                )
×
1849
                if err != nil {
×
1850
                        return err
×
1851
                }
×
1852

1853
                edge, err = getAndBuildEdgeInfo(
×
1854
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1855
                        node2,
×
1856
                )
×
1857
                if err != nil {
×
1858
                        return fmt.Errorf("unable to build channel info: %w",
×
1859
                                err)
×
1860
                }
×
1861

1862
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1863
                if err != nil {
×
1864
                        return fmt.Errorf("unable to extract channel "+
×
1865
                                "policies: %w", err)
×
1866
                }
×
1867

1868
                policy1, policy2, err = getAndBuildChanPolicies(
×
1869
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1870
                )
×
1871
                if err != nil {
×
1872
                        return fmt.Errorf("unable to build channel "+
×
1873
                                "policies: %w", err)
×
1874
                }
×
1875

1876
                return nil
×
1877
        }, sqldb.NoOpReset)
1878
        if err != nil {
×
1879
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1880
                        err)
×
1881
        }
×
1882

1883
        return edge, policy1, policy2, nil
×
1884
}
1885

1886
// HasChannelEdge returns true if the database knows of a channel edge with the
1887
// passed channel ID, and false otherwise. If an edge with that ID is found
1888
// within the graph, then two time stamps representing the last time the edge
1889
// was updated for both directed edges are returned along with the boolean. If
1890
// it is not found, then the zombie index is checked and its result is returned
1891
// as the second boolean.
1892
//
1893
// NOTE: part of the V1Store interface.
1894
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1895
        bool, error) {
×
1896

×
1897
        ctx := context.TODO()
×
1898

×
1899
        var (
×
1900
                exists          bool
×
1901
                isZombie        bool
×
1902
                node1LastUpdate time.Time
×
1903
                node2LastUpdate time.Time
×
1904
        )
×
1905

×
1906
        // We'll query the cache with the shared lock held to allow multiple
×
1907
        // readers to access values in the cache concurrently if they exist.
×
1908
        s.cacheMu.RLock()
×
1909
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1910
                s.cacheMu.RUnlock()
×
1911
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1912
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1913
                exists, isZombie = entry.flags.unpack()
×
1914

×
1915
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1916
        }
×
1917
        s.cacheMu.RUnlock()
×
1918

×
1919
        s.cacheMu.Lock()
×
1920
        defer s.cacheMu.Unlock()
×
1921

×
1922
        // The item was not found with the shared lock, so we'll acquire the
×
1923
        // exclusive lock and check the cache again in case another method added
×
1924
        // the entry to the cache while no lock was held.
×
1925
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1926
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1927
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1928
                exists, isZombie = entry.flags.unpack()
×
1929

×
1930
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1931
        }
×
1932

1933
        chanIDB := channelIDToBytes(chanID)
×
1934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1935
                channel, err := db.GetChannelBySCID(
×
1936
                        ctx, sqlc.GetChannelBySCIDParams{
×
1937
                                Scid:    chanIDB,
×
1938
                                Version: int16(ProtocolV1),
×
1939
                        },
×
1940
                )
×
1941
                if errors.Is(err, sql.ErrNoRows) {
×
1942
                        // Check if it is a zombie channel.
×
1943
                        isZombie, err = db.IsZombieChannel(
×
1944
                                ctx, sqlc.IsZombieChannelParams{
×
1945
                                        Scid:    chanIDB,
×
1946
                                        Version: int16(ProtocolV1),
×
1947
                                },
×
1948
                        )
×
1949
                        if err != nil {
×
1950
                                return fmt.Errorf("could not check if channel "+
×
1951
                                        "is zombie: %w", err)
×
1952
                        }
×
1953

1954
                        return nil
×
1955
                } else if err != nil {
×
1956
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1957
                }
×
1958

1959
                exists = true
×
1960

×
1961
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1962
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1963
                                Version:   int16(ProtocolV1),
×
1964
                                ChannelID: channel.ID,
×
1965
                                NodeID:    channel.NodeID1,
×
1966
                        },
×
1967
                )
×
1968
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1969
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1970
                                err)
×
1971
                } else if err == nil {
×
1972
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1973
                }
×
1974

1975
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1976
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1977
                                Version:   int16(ProtocolV1),
×
1978
                                ChannelID: channel.ID,
×
1979
                                NodeID:    channel.NodeID2,
×
1980
                        },
×
1981
                )
×
1982
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1983
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1984
                                err)
×
1985
                } else if err == nil {
×
1986
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1987
                }
×
1988

1989
                return nil
×
1990
        }, sqldb.NoOpReset)
1991
        if err != nil {
×
1992
                return time.Time{}, time.Time{}, false, false,
×
1993
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1994
        }
×
1995

1996
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1997
                upd1Time: node1LastUpdate.Unix(),
×
1998
                upd2Time: node2LastUpdate.Unix(),
×
1999
                flags:    packRejectFlags(exists, isZombie),
×
2000
        })
×
2001

×
2002
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2003
}
2004

2005
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2006
// passed channel point (outpoint). If the passed channel doesn't exist within
2007
// the database, then ErrEdgeNotFound is returned.
2008
//
2009
// NOTE: part of the V1Store interface.
2010
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2011
        var (
×
2012
                ctx       = context.TODO()
×
2013
                channelID uint64
×
2014
        )
×
2015
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2016
                chanID, err := db.GetSCIDByOutpoint(
×
2017
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2018
                                Outpoint: chanPoint.String(),
×
2019
                                Version:  int16(ProtocolV1),
×
2020
                        },
×
2021
                )
×
2022
                if errors.Is(err, sql.ErrNoRows) {
×
2023
                        return ErrEdgeNotFound
×
2024
                } else if err != nil {
×
2025
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2026
                                err)
×
2027
                }
×
2028

2029
                channelID = byteOrder.Uint64(chanID)
×
2030

×
2031
                return nil
×
2032
        }, sqldb.NoOpReset)
2033
        if err != nil {
×
2034
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2035
        }
×
2036

2037
        return channelID, nil
×
2038
}
2039

2040
// IsPublicNode is a helper method that determines whether the node with the
2041
// given public key is seen as a public node in the graph from the graph's
2042
// source node's point of view.
2043
//
2044
// NOTE: part of the V1Store interface.
2045
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2046
        ctx := context.TODO()
×
2047

×
2048
        var isPublic bool
×
2049
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2050
                var err error
×
2051
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2052

×
2053
                return err
×
2054
        }, sqldb.NoOpReset)
×
2055
        if err != nil {
×
2056
                return false, fmt.Errorf("unable to check if node is "+
×
2057
                        "public: %w", err)
×
2058
        }
×
2059

2060
        return isPublic, nil
×
2061
}
2062

2063
// FetchChanInfos returns the set of channel edges that correspond to the passed
2064
// channel ID's. If an edge is the query is unknown to the database, it will
2065
// skipped and the result will contain only those edges that exist at the time
2066
// of the query. This can be used to respond to peer queries that are seeking to
2067
// fill in gaps in their view of the channel graph.
2068
//
2069
// NOTE: part of the V1Store interface.
2070
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2071
        var (
×
2072
                ctx   = context.TODO()
×
2073
                edges = make(map[uint64]ChannelEdge)
×
2074
        )
×
2075
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2076
                chanCallBack := func(ctx context.Context,
×
2077
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2078

×
2079
                        node1, node2, err := buildNodes(
×
2080
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2081
                        )
×
2082
                        if err != nil {
×
2083
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2084
                                        err)
×
2085
                        }
×
2086

2087
                        edge, err := getAndBuildEdgeInfo(
×
2088
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2089
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2090
                        )
×
2091
                        if err != nil {
×
2092
                                return fmt.Errorf("unable to build "+
×
2093
                                        "channel info: %w", err)
×
2094
                        }
×
2095

2096
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2097
                        if err != nil {
×
2098
                                return fmt.Errorf("unable to extract channel "+
×
2099
                                        "policies: %w", err)
×
2100
                        }
×
2101

2102
                        p1, p2, err := getAndBuildChanPolicies(
×
2103
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2104
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2105
                        )
×
2106
                        if err != nil {
×
2107
                                return fmt.Errorf("unable to build channel "+
×
2108
                                        "policies: %w", err)
×
2109
                        }
×
2110

2111
                        edges[edge.ChannelID] = ChannelEdge{
×
2112
                                Info:    edge,
×
2113
                                Policy1: p1,
×
2114
                                Policy2: p2,
×
2115
                                Node1:   node1,
×
2116
                                Node2:   node2,
×
2117
                        }
×
2118

×
2119
                        return nil
×
2120
                }
2121

2122
                return s.forEachChanWithPoliciesInSCIDList(
×
2123
                        ctx, db, chanCallBack, chanIDs,
×
2124
                )
×
2125
        }, func() {
×
2126
                clear(edges)
×
2127
        })
×
2128
        if err != nil {
×
2129
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2130
        }
×
2131

2132
        res := make([]ChannelEdge, 0, len(edges))
×
2133
        for _, chanID := range chanIDs {
×
2134
                edge, ok := edges[chanID]
×
2135
                if !ok {
×
2136
                        continue
×
2137
                }
2138

2139
                res = append(res, edge)
×
2140
        }
2141

2142
        return res, nil
×
2143
}
2144

2145
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2146
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2147
// channels in a paginated manner.
2148
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2149
        db SQLQueries, cb func(ctx context.Context,
2150
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2151
        chanIDs []uint64) error {
×
2152

×
2153
        queryWrapper := func(ctx context.Context,
×
2154
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2155
                error) {
×
2156

×
2157
                return db.GetChannelsBySCIDWithPolicies(
×
2158
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2159
                                Version: int16(ProtocolV1),
×
2160
                                Scids:   scids,
×
2161
                        },
×
2162
                )
×
2163
        }
×
2164

NEW
2165
        return sqldb.ExecuteBatchQuery(
×
NEW
2166
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
NEW
2167
                cb,
×
UNCOV
2168
        )
×
2169
}
2170

2171
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2172
// ID's that we don't know and are not known zombies of the passed set. In other
2173
// words, we perform a set difference of our set of chan ID's and the ones
2174
// passed in. This method can be used by callers to determine the set of
2175
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2176
// known zombies is also returned.
2177
//
2178
// NOTE: part of the V1Store interface.
2179
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2180
        []ChannelUpdateInfo, error) {
×
2181

×
2182
        var (
×
2183
                ctx          = context.TODO()
×
2184
                newChanIDs   []uint64
×
2185
                knownZombies []ChannelUpdateInfo
×
2186
                infoLookup   = make(
×
2187
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2188
                )
×
2189
        )
×
2190

×
2191
        // We first build a lookup map of the channel ID's to the
×
2192
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2193
        // already know about.
×
2194
        for _, chanInfo := range chansInfo {
×
2195
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2196
        }
×
2197

2198
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2199
                // The call-back function deletes known channels from
×
2200
                // infoLookup, so that we can later check which channels are
×
2201
                // zombies by only looking at the remaining channels in the set.
×
2202
                cb := func(ctx context.Context,
×
2203
                        channel sqlc.GraphChannel) error {
×
2204

×
2205
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2206

×
2207
                        return nil
×
2208
                }
×
2209

2210
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2211
                if err != nil {
×
2212
                        return fmt.Errorf("unable to iterate through "+
×
2213
                                "channels: %w", err)
×
2214
                }
×
2215

2216
                // We want to ensure that we deal with the channels in the
2217
                // same order that they were passed in, so we iterate over the
2218
                // original chansInfo slice and then check if that channel is
2219
                // still in the infoLookup map.
2220
                for _, chanInfo := range chansInfo {
×
2221
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2222
                        if _, ok := infoLookup[channelID]; !ok {
×
2223
                                continue
×
2224
                        }
2225

2226
                        isZombie, err := db.IsZombieChannel(
×
2227
                                ctx, sqlc.IsZombieChannelParams{
×
2228
                                        Scid:    channelIDToBytes(channelID),
×
2229
                                        Version: int16(ProtocolV1),
×
2230
                                },
×
2231
                        )
×
2232
                        if err != nil {
×
2233
                                return fmt.Errorf("unable to fetch zombie "+
×
2234
                                        "channel: %w", err)
×
2235
                        }
×
2236

2237
                        if isZombie {
×
2238
                                knownZombies = append(knownZombies, chanInfo)
×
2239

×
2240
                                continue
×
2241
                        }
2242

2243
                        newChanIDs = append(newChanIDs, channelID)
×
2244
                }
2245

2246
                return nil
×
2247
        }, func() {
×
2248
                newChanIDs = nil
×
2249
                knownZombies = nil
×
2250
                // Rebuild the infoLookup map in case of a rollback.
×
2251
                for _, chanInfo := range chansInfo {
×
2252
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2253
                        infoLookup[scid] = chanInfo
×
2254
                }
×
2255
        })
2256
        if err != nil {
×
2257
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2258
        }
×
2259

2260
        return newChanIDs, knownZombies, nil
×
2261
}
2262

2263
// forEachChanInSCIDList is a helper method that executes a paged query
2264
// against the database to fetch all channels that match the passed
2265
// ChannelUpdateInfo slice. The callback function is called for each channel
2266
// that is found.
2267
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2268
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2269
        chansInfo []ChannelUpdateInfo) error {
×
2270

×
2271
        queryWrapper := func(ctx context.Context,
×
2272
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2273

×
2274
                return db.GetChannelsBySCIDs(
×
2275
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2276
                                Version: int16(ProtocolV1),
×
2277
                                Scids:   scids,
×
2278
                        },
×
2279
                )
×
2280
        }
×
2281

2282
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2283
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2284

×
2285
                return channelIDToBytes(channelID)
×
2286
        }
×
2287

NEW
2288
        return sqldb.ExecuteBatchQuery(
×
NEW
2289
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
NEW
2290
                cb,
×
UNCOV
2291
        )
×
2292
}
2293

2294
// PruneGraphNodes is a garbage collection method which attempts to prune out
2295
// any nodes from the channel graph that are currently unconnected. This ensure
2296
// that we only maintain a graph of reachable nodes. In the event that a pruned
2297
// node gains more channels, it will be re-added back to the graph.
2298
//
2299
// NOTE: this prunes nodes across protocol versions. It will never prune the
2300
// source nodes.
2301
//
2302
// NOTE: part of the V1Store interface.
2303
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2304
        var ctx = context.TODO()
×
2305

×
2306
        var prunedNodes []route.Vertex
×
2307
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2308
                var err error
×
2309
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2310

×
2311
                return err
×
2312
        }, func() {
×
2313
                prunedNodes = nil
×
2314
        })
×
2315
        if err != nil {
×
2316
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2317
        }
×
2318

2319
        return prunedNodes, nil
×
2320
}
2321

2322
// PruneGraph prunes newly closed channels from the channel graph in response
2323
// to a new block being solved on the network. Any transactions which spend the
2324
// funding output of any known channels within he graph will be deleted.
2325
// Additionally, the "prune tip", or the last block which has been used to
2326
// prune the graph is stored so callers can ensure the graph is fully in sync
2327
// with the current UTXO state. A slice of channels that have been closed by
2328
// the target block along with any pruned nodes are returned if the function
2329
// succeeds without error.
2330
//
2331
// NOTE: part of the V1Store interface.
2332
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2333
        blockHash *chainhash.Hash, blockHeight uint32) (
2334
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2335

×
2336
        ctx := context.TODO()
×
2337

×
2338
        s.cacheMu.Lock()
×
2339
        defer s.cacheMu.Unlock()
×
2340

×
2341
        var (
×
2342
                closedChans []*models.ChannelEdgeInfo
×
2343
                prunedNodes []route.Vertex
×
2344
        )
×
2345
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2346
                var chansToDelete []int64
×
2347

×
2348
                // Define the callback function for processing each channel.
×
2349
                channelCallback := func(ctx context.Context,
×
2350
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2351

×
2352
                        node1, node2, err := buildNodeVertices(
×
2353
                                row.Node1Pubkey, row.Node2Pubkey,
×
2354
                        )
×
2355
                        if err != nil {
×
2356
                                return err
×
2357
                        }
×
2358

2359
                        info, err := getAndBuildEdgeInfo(
×
2360
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2361
                                node1, node2,
×
2362
                        )
×
2363
                        if err != nil {
×
2364
                                return err
×
2365
                        }
×
2366

2367
                        closedChans = append(closedChans, info)
×
2368
                        chansToDelete = append(
×
2369
                                chansToDelete, row.GraphChannel.ID,
×
2370
                        )
×
2371

×
2372
                        return nil
×
2373
                }
2374

2375
                err := s.forEachChanInOutpoints(
×
2376
                        ctx, db, spentOutputs, channelCallback,
×
2377
                )
×
2378
                if err != nil {
×
2379
                        return fmt.Errorf("unable to fetch channels by "+
×
2380
                                "outpoints: %w", err)
×
2381
                }
×
2382

2383
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2384
                if err != nil {
×
2385
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2386
                }
×
2387

2388
                err = db.UpsertPruneLogEntry(
×
2389
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2390
                                BlockHash:   blockHash[:],
×
2391
                                BlockHeight: int64(blockHeight),
×
2392
                        },
×
2393
                )
×
2394
                if err != nil {
×
2395
                        return fmt.Errorf("unable to insert prune log "+
×
2396
                                "entry: %w", err)
×
2397
                }
×
2398

2399
                // Now that we've pruned some channels, we'll also prune any
2400
                // nodes that no longer have any channels.
2401
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2402
                if err != nil {
×
2403
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2404
                                err)
×
2405
                }
×
2406

2407
                return nil
×
2408
        }, func() {
×
2409
                prunedNodes = nil
×
2410
                closedChans = nil
×
2411
        })
×
2412
        if err != nil {
×
2413
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2414
        }
×
2415

2416
        for _, channel := range closedChans {
×
2417
                s.rejectCache.remove(channel.ChannelID)
×
2418
                s.chanCache.remove(channel.ChannelID)
×
2419
        }
×
2420

2421
        return closedChans, prunedNodes, nil
×
2422
}
2423

2424
// forEachChanInOutpoints is a helper function that executes a paginated
2425
// query to fetch channels by their outpoints and applies the given call-back
2426
// to each.
2427
//
2428
// NOTE: this fetches channels for all protocol versions.
2429
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2430
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2431
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2432

×
2433
        // Create a wrapper that uses the transaction's db instance to execute
×
2434
        // the query.
×
2435
        queryWrapper := func(ctx context.Context,
×
2436
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2437
                error) {
×
2438

×
2439
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2440
        }
×
2441

2442
        // Define the conversion function from Outpoint to string.
2443
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2444
                return outpoint.String()
×
2445
        }
×
2446

NEW
2447
        return sqldb.ExecuteBatchQuery(
×
NEW
2448
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2449
                queryWrapper, cb,
×
2450
        )
×
2451
}
2452

2453
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2454
        dbIDs []int64) error {
×
2455

×
2456
        // Create a wrapper that uses the transaction's db instance to execute
×
2457
        // the query.
×
2458
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2459
                return nil, db.DeleteChannels(ctx, ids)
×
2460
        }
×
2461

2462
        idConverter := func(id int64) int64 {
×
2463
                return id
×
2464
        }
×
2465

NEW
2466
        return sqldb.ExecuteBatchQuery(
×
NEW
2467
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2468
                queryWrapper, func(ctx context.Context, _ any) error {
×
2469
                        return nil
×
2470
                },
×
2471
        )
2472
}
2473

2474
// ChannelView returns the verifiable edge information for each active channel
2475
// within the known channel graph. The set of UTXOs (along with their scripts)
2476
// returned are the ones that need to be watched on chain to detect channel
2477
// closes on the resident blockchain.
2478
//
2479
// NOTE: part of the V1Store interface.
2480
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2481
        var (
×
2482
                ctx        = context.TODO()
×
2483
                edgePoints []EdgePoint
×
2484
        )
×
2485

×
NEW
2486
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2487
                handleChannel := func(_ context.Context,
×
NEW
2488
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2489

×
NEW
2490
                        pkScript, err := genMultiSigP2WSH(
×
NEW
2491
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
NEW
2492
                        )
×
NEW
2493
                        if err != nil {
×
NEW
2494
                                return err
×
NEW
2495
                        }
×
2496

NEW
2497
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
NEW
2498
                        if err != nil {
×
NEW
2499
                                return err
×
NEW
2500
                        }
×
2501

NEW
2502
                        edgePoints = append(edgePoints, EdgePoint{
×
NEW
2503
                                FundingPkScript: pkScript,
×
NEW
2504
                                OutPoint:        *op,
×
NEW
2505
                        })
×
2506

×
NEW
2507
                        return nil
×
2508
                }
2509

NEW
2510
                queryFunc := func(ctx context.Context, lastID int64,
×
NEW
2511
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
NEW
2512

×
NEW
2513
                        return db.ListChannelsPaginated(
×
2514
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2515
                                        Version: int16(ProtocolV1),
×
2516
                                        ID:      lastID,
×
NEW
2517
                                        Limit:   limit,
×
2518
                                },
×
2519
                        )
×
NEW
2520
                }
×
2521

NEW
2522
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
NEW
2523
                        return row.ID
×
UNCOV
2524
                }
×
2525

NEW
2526
                return sqldb.ExecutePaginatedQuery(
×
NEW
2527
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
2528
                        extractCursor, handleChannel,
×
NEW
2529
                )
×
2530
        }, func() {
×
2531
                edgePoints = nil
×
2532
        })
×
2533
        if err != nil {
×
2534
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2535
        }
×
2536

2537
        return edgePoints, nil
×
2538
}
2539

2540
// PruneTip returns the block height and hash of the latest block that has been
2541
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2542
// to tell if the graph is currently in sync with the current best known UTXO
2543
// state.
2544
//
2545
// NOTE: part of the V1Store interface.
2546
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2547
        var (
×
2548
                ctx       = context.TODO()
×
2549
                tipHash   chainhash.Hash
×
2550
                tipHeight uint32
×
2551
        )
×
2552
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2553
                pruneTip, err := db.GetPruneTip(ctx)
×
2554
                if errors.Is(err, sql.ErrNoRows) {
×
2555
                        return ErrGraphNeverPruned
×
2556
                } else if err != nil {
×
2557
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2558
                }
×
2559

2560
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2561
                tipHeight = uint32(pruneTip.BlockHeight)
×
2562

×
2563
                return nil
×
2564
        }, sqldb.NoOpReset)
2565
        if err != nil {
×
2566
                return nil, 0, err
×
2567
        }
×
2568

2569
        return &tipHash, tipHeight, nil
×
2570
}
2571

2572
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2573
//
2574
// NOTE: this prunes nodes across protocol versions. It will never prune the
2575
// source nodes.
2576
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2577
        db SQLQueries) ([]route.Vertex, error) {
×
2578

×
2579
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2580
        if err != nil {
×
2581
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2582
                        "nodes: %w", err)
×
2583
        }
×
2584

2585
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2586
        for i, nodeKey := range nodeKeys {
×
2587
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2588
                if err != nil {
×
2589
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2590
                                "from bytes: %w", err)
×
2591
                }
×
2592

2593
                prunedNodes[i] = pub
×
2594
        }
2595

2596
        return prunedNodes, nil
×
2597
}
2598

2599
// DisconnectBlockAtHeight is used to indicate that the block specified
2600
// by the passed height has been disconnected from the main chain. This
2601
// will "rewind" the graph back to the height below, deleting channels
2602
// that are no longer confirmed from the graph. The prune log will be
2603
// set to the last prune height valid for the remaining chain.
2604
// Channels that were removed from the graph resulting from the
2605
// disconnected block are returned.
2606
//
2607
// NOTE: part of the V1Store interface.
2608
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2609
        []*models.ChannelEdgeInfo, error) {
×
2610

×
2611
        ctx := context.TODO()
×
2612

×
2613
        var (
×
2614
                // Every channel having a ShortChannelID starting at 'height'
×
2615
                // will no longer be confirmed.
×
2616
                startShortChanID = lnwire.ShortChannelID{
×
2617
                        BlockHeight: height,
×
2618
                }
×
2619

×
2620
                // Delete everything after this height from the db up until the
×
2621
                // SCID alias range.
×
2622
                endShortChanID = aliasmgr.StartingAlias
×
2623

×
2624
                removedChans []*models.ChannelEdgeInfo
×
2625

×
2626
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2627
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2628
        )
×
2629

×
2630
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2631
                rows, err := db.GetChannelsBySCIDRange(
×
2632
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2633
                                StartScid: chanIDStart,
×
2634
                                EndScid:   chanIDEnd,
×
2635
                        },
×
2636
                )
×
2637
                if err != nil {
×
2638
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2639
                }
×
2640

2641
                chanIDsToDelete := make([]int64, len(rows))
×
2642
                for i, row := range rows {
×
2643
                        node1, node2, err := buildNodeVertices(
×
2644
                                row.Node1PubKey, row.Node2PubKey,
×
2645
                        )
×
2646
                        if err != nil {
×
2647
                                return err
×
2648
                        }
×
2649

2650
                        channel, err := getAndBuildEdgeInfo(
×
2651
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2652
                                node1, node2,
×
2653
                        )
×
2654
                        if err != nil {
×
2655
                                return err
×
2656
                        }
×
2657

2658
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2659
                        removedChans = append(removedChans, channel)
×
2660
                }
2661

2662
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2663
                if err != nil {
×
2664
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2665
                }
×
2666

2667
                return db.DeletePruneLogEntriesInRange(
×
2668
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2669
                                StartHeight: int64(height),
×
2670
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2671
                        },
×
2672
                )
×
2673
        }, func() {
×
2674
                removedChans = nil
×
2675
        })
×
2676
        if err != nil {
×
2677
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2678
                        "height: %w", err)
×
2679
        }
×
2680

2681
        for _, channel := range removedChans {
×
2682
                s.rejectCache.remove(channel.ChannelID)
×
2683
                s.chanCache.remove(channel.ChannelID)
×
2684
        }
×
2685

2686
        return removedChans, nil
×
2687
}
2688

2689
// AddEdgeProof sets the proof of an existing edge in the graph database.
2690
//
2691
// NOTE: part of the V1Store interface.
2692
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2693
        proof *models.ChannelAuthProof) error {
×
2694

×
2695
        var (
×
2696
                ctx       = context.TODO()
×
2697
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2698
        )
×
2699

×
2700
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2701
                res, err := db.AddV1ChannelProof(
×
2702
                        ctx, sqlc.AddV1ChannelProofParams{
×
2703
                                Scid:              scidBytes,
×
2704
                                Node1Signature:    proof.NodeSig1Bytes,
×
2705
                                Node2Signature:    proof.NodeSig2Bytes,
×
2706
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2707
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2708
                        },
×
2709
                )
×
2710
                if err != nil {
×
2711
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2712
                }
×
2713

2714
                n, err := res.RowsAffected()
×
2715
                if err != nil {
×
2716
                        return err
×
2717
                }
×
2718

2719
                if n == 0 {
×
2720
                        return fmt.Errorf("no rows affected when adding edge "+
×
2721
                                "proof for SCID %v", scid)
×
2722
                } else if n > 1 {
×
2723
                        return fmt.Errorf("multiple rows affected when adding "+
×
2724
                                "edge proof for SCID %v: %d rows affected",
×
2725
                                scid, n)
×
2726
                }
×
2727

2728
                return nil
×
2729
        }, sqldb.NoOpReset)
2730
        if err != nil {
×
2731
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2732
        }
×
2733

2734
        return nil
×
2735
}
2736

2737
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2738
// that we can ignore channel announcements that we know to be closed without
2739
// having to validate them and fetch a block.
2740
//
2741
// NOTE: part of the V1Store interface.
2742
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2743
        var (
×
2744
                ctx     = context.TODO()
×
2745
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2746
        )
×
2747

×
2748
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2749
                return db.InsertClosedChannel(ctx, chanIDB)
×
2750
        }, sqldb.NoOpReset)
×
2751
}
2752

2753
// IsClosedScid checks whether a channel identified by the passed in scid is
2754
// closed. This helps avoid having to perform expensive validation checks.
2755
//
2756
// NOTE: part of the V1Store interface.
2757
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2758
        var (
×
2759
                ctx      = context.TODO()
×
2760
                isClosed bool
×
2761
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2762
        )
×
2763
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2764
                var err error
×
2765
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2766
                if err != nil {
×
2767
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2768
                                err)
×
2769
                }
×
2770

2771
                return nil
×
2772
        }, sqldb.NoOpReset)
2773
        if err != nil {
×
2774
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2775
                        err)
×
2776
        }
×
2777

2778
        return isClosed, nil
×
2779
}
2780

2781
// GraphSession will provide the call-back with access to a NodeTraverser
2782
// instance which can be used to perform queries against the channel graph.
2783
//
2784
// NOTE: part of the V1Store interface.
2785
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2786
        reset func()) error {
×
2787

×
2788
        var ctx = context.TODO()
×
2789

×
2790
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2791
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2792
        }, reset)
×
2793
}
2794

2795
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2796
// read only transaction for a consistent view of the graph.
2797
type sqlNodeTraverser struct {
2798
        db    SQLQueries
2799
        chain chainhash.Hash
2800
}
2801

2802
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2803
// NodeTraverser interface.
2804
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2805

2806
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2807
func newSQLNodeTraverser(db SQLQueries,
2808
        chain chainhash.Hash) *sqlNodeTraverser {
×
2809

×
2810
        return &sqlNodeTraverser{
×
2811
                db:    db,
×
2812
                chain: chain,
×
2813
        }
×
2814
}
×
2815

2816
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2817
// node.
2818
//
2819
// NOTE: Part of the NodeTraverser interface.
2820
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2821
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2822

×
2823
        ctx := context.TODO()
×
2824

×
2825
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2826
}
×
2827

2828
// FetchNodeFeatures returns the features of the given node. If the node is
2829
// unknown, assume no additional features are supported.
2830
//
2831
// NOTE: Part of the NodeTraverser interface.
2832
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2833
        *lnwire.FeatureVector, error) {
×
2834

×
2835
        ctx := context.TODO()
×
2836

×
2837
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2838
}
×
2839

2840
// forEachNodeDirectedChannel iterates through all channels of a given
2841
// node, executing the passed callback on the directed edge representing the
2842
// channel and its incoming policy. If the node is not found, no error is
2843
// returned.
2844
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2845
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2846

×
2847
        toNodeCallback := func() route.Vertex {
×
2848
                return nodePub
×
2849
        }
×
2850

2851
        dbID, err := db.GetNodeIDByPubKey(
×
2852
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2853
                        Version: int16(ProtocolV1),
×
2854
                        PubKey:  nodePub[:],
×
2855
                },
×
2856
        )
×
2857
        if errors.Is(err, sql.ErrNoRows) {
×
2858
                return nil
×
2859
        } else if err != nil {
×
2860
                return fmt.Errorf("unable to fetch node: %w", err)
×
2861
        }
×
2862

2863
        rows, err := db.ListChannelsByNodeID(
×
2864
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2865
                        Version: int16(ProtocolV1),
×
2866
                        NodeID1: dbID,
×
2867
                },
×
2868
        )
×
2869
        if err != nil {
×
2870
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2871
        }
×
2872

2873
        // Exit early if there are no channels for this node so we don't
2874
        // do the unnecessary feature fetching.
2875
        if len(rows) == 0 {
×
2876
                return nil
×
2877
        }
×
2878

2879
        features, err := getNodeFeatures(ctx, db, dbID)
×
2880
        if err != nil {
×
2881
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2882
        }
×
2883

2884
        for _, row := range rows {
×
2885
                node1, node2, err := buildNodeVertices(
×
2886
                        row.Node1Pubkey, row.Node2Pubkey,
×
2887
                )
×
2888
                if err != nil {
×
2889
                        return fmt.Errorf("unable to build node vertices: %w",
×
2890
                                err)
×
2891
                }
×
2892

2893
                edge := buildCacheableChannelInfo(
×
2894
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2895
                        node1, node2,
×
2896
                )
×
2897

×
2898
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2899
                if err != nil {
×
2900
                        return err
×
2901
                }
×
2902

NEW
2903
                p1, p2, err := buildCachedChanPolicies(
×
NEW
2904
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
2905
                )
×
NEW
2906
                if err != nil {
×
NEW
2907
                        return err
×
UNCOV
2908
                }
×
2909

2910
                // Determine the outgoing and incoming policy for this
2911
                // channel and node combo.
2912
                outPolicy, inPolicy := p1, p2
×
2913
                if p1 != nil && node2 == nodePub {
×
2914
                        outPolicy, inPolicy = p2, p1
×
2915
                } else if p2 != nil && node1 != nodePub {
×
2916
                        outPolicy, inPolicy = p2, p1
×
2917
                }
×
2918

2919
                var cachedInPolicy *models.CachedEdgePolicy
×
2920
                if inPolicy != nil {
×
2921
                        cachedInPolicy = inPolicy
×
2922
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2923
                        cachedInPolicy.ToNodeFeatures = features
×
2924
                }
×
2925

2926
                directedChannel := &DirectedChannel{
×
2927
                        ChannelID:    edge.ChannelID,
×
2928
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2929
                        OtherNode:    edge.NodeKey2Bytes,
×
2930
                        Capacity:     edge.Capacity,
×
2931
                        OutPolicySet: outPolicy != nil,
×
2932
                        InPolicy:     cachedInPolicy,
×
2933
                }
×
2934
                if outPolicy != nil {
×
2935
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2936
                                directedChannel.InboundFee = fee
×
2937
                        })
×
2938
                }
2939

2940
                if nodePub == edge.NodeKey2Bytes {
×
2941
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2942
                }
×
2943

2944
                if err := cb(directedChannel); err != nil {
×
2945
                        return err
×
2946
                }
×
2947
        }
2948

2949
        return nil
×
2950
}
2951

2952
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2953
// and executes the provided callback for each node.
2954
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2955
        db SQLQueries,
2956
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2957

×
NEW
2958
        handleNode := func(_ context.Context,
×
NEW
2959
                node sqlc.ListNodeIDsAndPubKeysRow) error {
×
2960

×
NEW
2961
                var pub route.Vertex
×
NEW
2962
                copy(pub[:], node.PubKey)
×
NEW
2963

×
NEW
2964
                return cb(node.ID, pub)
×
NEW
2965
        }
×
2966

NEW
2967
        queryFunc := func(ctx context.Context, lastID int64,
×
NEW
2968
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
NEW
2969

×
NEW
2970
                return db.ListNodeIDsAndPubKeys(
×
2971
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2972
                                Version: int16(ProtocolV1),
×
2973
                                ID:      lastID,
×
NEW
2974
                                Limit:   limit,
×
2975
                        },
×
2976
                )
×
NEW
2977
        }
×
2978

NEW
2979
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
NEW
2980
                return row.ID
×
UNCOV
2981
        }
×
2982

NEW
2983
        return sqldb.ExecutePaginatedQuery(
×
NEW
2984
                ctx, cfg, int64(-1), queryFunc, extractCursor, handleNode,
×
NEW
2985
        )
×
2986
}
2987

2988
// forEachNodeChannel iterates through all channels of a node, executing
2989
// the passed callback on each. The call-back is provided with the channel's
2990
// edge information, the outgoing policy and the incoming policy for the
2991
// channel and node combo.
2992
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2993
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2994
                *models.ChannelEdgePolicy,
2995
                *models.ChannelEdgePolicy) error) error {
×
2996

×
2997
        // Get all the V1 channels for this node.Add commentMore actions
×
2998
        rows, err := db.ListChannelsByNodeID(
×
2999
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3000
                        Version: int16(ProtocolV1),
×
3001
                        NodeID1: id,
×
3002
                },
×
3003
        )
×
3004
        if err != nil {
×
3005
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3006
        }
×
3007

3008
        // Call the call-back for each channel and its known policies.
3009
        for _, row := range rows {
×
3010
                node1, node2, err := buildNodeVertices(
×
3011
                        row.Node1Pubkey, row.Node2Pubkey,
×
3012
                )
×
3013
                if err != nil {
×
3014
                        return fmt.Errorf("unable to build node vertices: %w",
×
3015
                                err)
×
3016
                }
×
3017

3018
                edge, err := getAndBuildEdgeInfo(
×
3019
                        ctx, db, chain, row.GraphChannel, node1, node2,
×
3020
                )
×
3021
                if err != nil {
×
3022
                        return fmt.Errorf("unable to build channel info: %w",
×
3023
                                err)
×
3024
                }
×
3025

3026
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3027
                if err != nil {
×
3028
                        return fmt.Errorf("unable to extract channel "+
×
3029
                                "policies: %w", err)
×
3030
                }
×
3031

3032
                p1, p2, err := getAndBuildChanPolicies(
×
3033
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3034
                )
×
3035
                if err != nil {
×
3036
                        return fmt.Errorf("unable to build channel "+
×
3037
                                "policies: %w", err)
×
3038
                }
×
3039

3040
                // Determine the outgoing and incoming policy for this
3041
                // channel and node combo.
3042
                p1ToNode := row.GraphChannel.NodeID2
×
3043
                p2ToNode := row.GraphChannel.NodeID1
×
3044
                outPolicy, inPolicy := p1, p2
×
3045
                if (p1 != nil && p1ToNode == id) ||
×
3046
                        (p2 != nil && p2ToNode != id) {
×
3047

×
3048
                        outPolicy, inPolicy = p2, p1
×
3049
                }
×
3050

3051
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3052
                        return err
×
3053
                }
×
3054
        }
3055

3056
        return nil
×
3057
}
3058

3059
// updateChanEdgePolicy upserts the channel policy info we have stored for
3060
// a channel we already know of.
3061
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3062
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3063
        error) {
×
3064

×
3065
        var (
×
3066
                node1Pub, node2Pub route.Vertex
×
3067
                isNode1            bool
×
3068
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3069
        )
×
3070

×
3071
        // Check that this edge policy refers to a channel that we already
×
3072
        // know of. We do this explicitly so that we can return the appropriate
×
3073
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3074
        // abort the transaction which would abort the entire batch.
×
3075
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3076
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3077
                        Scid:    chanIDB,
×
3078
                        Version: int16(ProtocolV1),
×
3079
                },
×
3080
        )
×
3081
        if errors.Is(err, sql.ErrNoRows) {
×
3082
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3083
        } else if err != nil {
×
3084
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3085
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3086
        }
×
3087

3088
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3089
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3090

×
3091
        // Figure out which node this edge is from.
×
3092
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3093
        nodeID := dbChan.NodeID1
×
3094
        if !isNode1 {
×
3095
                nodeID = dbChan.NodeID2
×
3096
        }
×
3097

3098
        var (
×
3099
                inboundBase sql.NullInt64
×
3100
                inboundRate sql.NullInt64
×
3101
        )
×
3102
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3103
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3104
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3105
        })
×
3106

3107
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3108
                Version:     int16(ProtocolV1),
×
3109
                ChannelID:   dbChan.ID,
×
3110
                NodeID:      nodeID,
×
3111
                Timelock:    int32(edge.TimeLockDelta),
×
3112
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3113
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3114
                MinHtlcMsat: int64(edge.MinHTLC),
×
3115
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3116
                Disabled: sql.NullBool{
×
3117
                        Valid: true,
×
3118
                        Bool:  edge.IsDisabled(),
×
3119
                },
×
3120
                MaxHtlcMsat: sql.NullInt64{
×
3121
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3122
                        Int64: int64(edge.MaxHTLC),
×
3123
                },
×
3124
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3125
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3126
                InboundBaseFeeMsat:      inboundBase,
×
3127
                InboundFeeRateMilliMsat: inboundRate,
×
3128
                Signature:               edge.SigBytes,
×
3129
        })
×
3130
        if err != nil {
×
3131
                return node1Pub, node2Pub, isNode1,
×
3132
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3133
        }
×
3134

3135
        // Convert the flat extra opaque data into a map of TLV types to
3136
        // values.
3137
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3138
        if err != nil {
×
3139
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3140
                        "marshal extra opaque data: %w", err)
×
3141
        }
×
3142

3143
        // Update the channel policy's extra signed fields.
3144
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3145
        if err != nil {
×
3146
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3147
                        "policy extra TLVs: %w", err)
×
3148
        }
×
3149

3150
        return node1Pub, node2Pub, isNode1, nil
×
3151
}
3152

3153
// getNodeByPubKey attempts to look up a target node by its public key.
3154
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3155
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3156

×
3157
        dbNode, err := db.GetNodeByPubKey(
×
3158
                ctx, sqlc.GetNodeByPubKeyParams{
×
3159
                        Version: int16(ProtocolV1),
×
3160
                        PubKey:  pubKey[:],
×
3161
                },
×
3162
        )
×
3163
        if errors.Is(err, sql.ErrNoRows) {
×
3164
                return 0, nil, ErrGraphNodeNotFound
×
3165
        } else if err != nil {
×
3166
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3167
        }
×
3168

3169
        node, err := buildNode(ctx, db, &dbNode)
×
3170
        if err != nil {
×
3171
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3172
        }
×
3173

3174
        return dbNode.ID, node, nil
×
3175
}
3176

3177
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3178
// provided parameters.
3179
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3180
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3181

×
3182
        return &models.CachedEdgeInfo{
×
3183
                ChannelID:     byteOrder.Uint64(scid),
×
3184
                NodeKey1Bytes: node1Pub,
×
3185
                NodeKey2Bytes: node2Pub,
×
3186
                Capacity:      btcutil.Amount(capacity),
×
3187
        }
×
3188
}
×
3189

3190
// buildNode constructs a LightningNode instance from the given database node
3191
// record. The node's features, addresses and extra signed fields are also
3192
// fetched from the database and set on the node.
3193
func buildNode(ctx context.Context, db SQLQueries,
3194
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
×
3195

×
3196
        // NOTE: buildNode is only used to load the data for a single node, and
×
3197
        // so no paged queries will be performed. This means that it's ok to
×
3198
        // used pass in default config values here.
×
NEW
3199
        cfg := sqldb.DefaultQueryConfig()
×
3200

×
3201
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3202
        if err != nil {
×
3203
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3204
                        err)
×
3205
        }
×
3206

3207
        return buildNodeWithBatchData(dbNode, data)
×
3208
}
3209

3210
// buildNodeWithBatchData builds a models.LightningNode instance
3211
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3212
// features/addresses/extra fields, then the corresponding fields are expected
3213
// to be present in the batchNodeData.
3214
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
3215
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3216

×
3217
        if dbNode.Version != int16(ProtocolV1) {
×
3218
                return nil, fmt.Errorf("unsupported node version: %d",
×
3219
                        dbNode.Version)
×
3220
        }
×
3221

3222
        var pub [33]byte
×
3223
        copy(pub[:], dbNode.PubKey)
×
3224

×
3225
        node := &models.LightningNode{
×
3226
                PubKeyBytes: pub,
×
3227
                Features:    lnwire.EmptyFeatureVector(),
×
3228
                LastUpdate:  time.Unix(0, 0),
×
3229
        }
×
3230

×
3231
        if len(dbNode.Signature) == 0 {
×
3232
                return node, nil
×
3233
        }
×
3234

3235
        node.HaveNodeAnnouncement = true
×
3236
        node.AuthSigBytes = dbNode.Signature
×
3237
        node.Alias = dbNode.Alias.String
×
3238
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3239

×
3240
        var err error
×
3241
        if dbNode.Color.Valid {
×
3242
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3243
                if err != nil {
×
3244
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3245
                                err)
×
3246
                }
×
3247
        }
3248

3249
        // Use preloaded features.
3250
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3251
                fv := lnwire.EmptyFeatureVector()
×
3252
                for _, bit := range features {
×
3253
                        fv.Set(lnwire.FeatureBit(bit))
×
3254
                }
×
3255
                node.Features = fv
×
3256
        }
3257

3258
        // Use preloaded addresses.
3259
        addresses, exists := batchData.addresses[dbNode.ID]
×
3260
        if exists && len(addresses) > 0 {
×
3261
                node.Addresses, err = buildNodeAddresses(addresses)
×
3262
                if err != nil {
×
3263
                        return nil, fmt.Errorf("unable to build addresses "+
×
3264
                                "for node(%d): %w", dbNode.ID, err)
×
3265
                }
×
3266
        }
3267

3268
        // Use preloaded extra fields.
3269
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3270
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3271
                if err != nil {
×
3272
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3273
                                "signed fields: %w", err)
×
3274
                }
×
3275
                if len(recs) != 0 {
×
3276
                        node.ExtraOpaqueData = recs
×
3277
                }
×
3278
        }
3279

3280
        return node, nil
×
3281
}
3282

3283
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3284
// with the preloaded data, and executes the provided callback for each node.
3285
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3286
        db SQLQueries, nodes []sqlc.GraphNode,
3287
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3288

×
3289
        // Extract node IDs for batch loading.
×
3290
        nodeIDs := make([]int64, len(nodes))
×
3291
        for i, node := range nodes {
×
3292
                nodeIDs[i] = node.ID
×
3293
        }
×
3294

3295
        // Batch load all related data for this page.
3296
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3297
        if err != nil {
×
3298
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3299
        }
×
3300

3301
        for _, dbNode := range nodes {
×
3302
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
3303
                if err != nil {
×
3304
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3305
                                dbNode.ID, err)
×
3306
                }
×
3307

3308
                if err := cb(dbNode.ID, node); err != nil {
×
3309
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3310
                                dbNode.ID, err)
×
3311
                }
×
3312
        }
3313

3314
        return nil
×
3315
}
3316

3317
// getNodeFeatures fetches the feature bits and constructs the feature vector
3318
// for a node with the given DB ID.
3319
func getNodeFeatures(ctx context.Context, db SQLQueries,
3320
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3321

×
3322
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3323
        if err != nil {
×
3324
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3325
                        nodeID, err)
×
3326
        }
×
3327

3328
        features := lnwire.EmptyFeatureVector()
×
3329
        for _, feature := range rows {
×
3330
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3331
        }
×
3332

3333
        return features, nil
×
3334
}
3335

3336
// upsertNode upserts the node record into the database. If the node already
3337
// exists, then the node's information is updated. If the node doesn't exist,
3338
// then a new node is created. The node's features, addresses and extra TLV
3339
// types are also updated. The node's DB ID is returned.
3340
func upsertNode(ctx context.Context, db SQLQueries,
3341
        node *models.LightningNode) (int64, error) {
×
3342

×
3343
        params := sqlc.UpsertNodeParams{
×
3344
                Version: int16(ProtocolV1),
×
3345
                PubKey:  node.PubKeyBytes[:],
×
3346
        }
×
3347

×
3348
        if node.HaveNodeAnnouncement {
×
3349
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3350
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3351
                params.Alias = sqldb.SQLStr(node.Alias)
×
3352
                params.Signature = node.AuthSigBytes
×
3353
        }
×
3354

3355
        nodeID, err := db.UpsertNode(ctx, params)
×
3356
        if err != nil {
×
3357
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3358
                        err)
×
3359
        }
×
3360

3361
        // We can exit here if we don't have the announcement yet.
3362
        if !node.HaveNodeAnnouncement {
×
3363
                return nodeID, nil
×
3364
        }
×
3365

3366
        // Update the node's features.
3367
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3368
        if err != nil {
×
3369
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3370
        }
×
3371

3372
        // Update the node's addresses.
3373
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3374
        if err != nil {
×
3375
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3376
        }
×
3377

3378
        // Convert the flat extra opaque data into a map of TLV types to
3379
        // values.
3380
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3381
        if err != nil {
×
3382
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3383
                        err)
×
3384
        }
×
3385

3386
        // Update the node's extra signed fields.
3387
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3388
        if err != nil {
×
3389
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3390
        }
×
3391

3392
        return nodeID, nil
×
3393
}
3394

3395
// upsertNodeFeatures updates the node's features node_features table. This
3396
// includes deleting any feature bits no longer present and inserting any new
3397
// feature bits. If the feature bit does not yet exist in the features table,
3398
// then an entry is created in that table first.
3399
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3400
        features *lnwire.FeatureVector) error {
×
3401

×
3402
        // Get any existing features for the node.
×
3403
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3404
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3405
                return err
×
3406
        }
×
3407

3408
        // Copy the nodes latest set of feature bits.
3409
        newFeatures := make(map[int32]struct{})
×
3410
        if features != nil {
×
3411
                for feature := range features.Features() {
×
3412
                        newFeatures[int32(feature)] = struct{}{}
×
3413
                }
×
3414
        }
3415

3416
        // For any current feature that already exists in the DB, remove it from
3417
        // the in-memory map. For any existing feature that does not exist in
3418
        // the in-memory map, delete it from the database.
3419
        for _, feature := range existingFeatures {
×
3420
                // The feature is still present, so there are no updates to be
×
3421
                // made.
×
3422
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3423
                        delete(newFeatures, feature.FeatureBit)
×
3424
                        continue
×
3425
                }
3426

3427
                // The feature is no longer present, so we remove it from the
3428
                // database.
3429
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3430
                        NodeID:     nodeID,
×
3431
                        FeatureBit: feature.FeatureBit,
×
3432
                })
×
3433
                if err != nil {
×
3434
                        return fmt.Errorf("unable to delete node(%d) "+
×
3435
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3436
                                err)
×
3437
                }
×
3438
        }
3439

3440
        // Any remaining entries in newFeatures are new features that need to be
3441
        // added to the database for the first time.
3442
        for feature := range newFeatures {
×
3443
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3444
                        NodeID:     nodeID,
×
3445
                        FeatureBit: feature,
×
3446
                })
×
3447
                if err != nil {
×
3448
                        return fmt.Errorf("unable to insert node(%d) "+
×
3449
                                "feature(%v): %w", nodeID, feature, err)
×
3450
                }
×
3451
        }
3452

3453
        return nil
×
3454
}
3455

3456
// fetchNodeFeatures fetches the features for a node with the given public key.
3457
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3458
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3459

×
3460
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3461
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3462
                        PubKey:  nodePub[:],
×
3463
                        Version: int16(ProtocolV1),
×
3464
                },
×
3465
        )
×
3466
        if err != nil {
×
3467
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3468
                        nodePub, err)
×
3469
        }
×
3470

3471
        features := lnwire.EmptyFeatureVector()
×
3472
        for _, bit := range rows {
×
3473
                features.Set(lnwire.FeatureBit(bit))
×
3474
        }
×
3475

3476
        return features, nil
×
3477
}
3478

3479
// dbAddressType is an enum type that represents the different address types
3480
// that we store in the node_addresses table. The address type determines how
3481
// the address is to be serialised/deserialize.
3482
type dbAddressType uint8
3483

3484
const (
3485
        addressTypeIPv4   dbAddressType = 1
3486
        addressTypeIPv6   dbAddressType = 2
3487
        addressTypeTorV2  dbAddressType = 3
3488
        addressTypeTorV3  dbAddressType = 4
3489
        addressTypeOpaque dbAddressType = math.MaxInt8
3490
)
3491

3492
// upsertNodeAddresses updates the node's addresses in the database. This
3493
// includes deleting any existing addresses and inserting the new set of
3494
// addresses. The deletion is necessary since the ordering of the addresses may
3495
// change, and we need to ensure that the database reflects the latest set of
3496
// addresses so that at the time of reconstructing the node announcement, the
3497
// order is preserved and the signature over the message remains valid.
3498
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3499
        addresses []net.Addr) error {
×
3500

×
3501
        // Delete any existing addresses for the node. This is required since
×
3502
        // even if the new set of addresses is the same, the ordering may have
×
3503
        // changed for a given address type.
×
3504
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3505
        if err != nil {
×
3506
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3507
                        nodeID, err)
×
3508
        }
×
3509

3510
        // Copy the nodes latest set of addresses.
3511
        newAddresses := map[dbAddressType][]string{
×
3512
                addressTypeIPv4:   {},
×
3513
                addressTypeIPv6:   {},
×
3514
                addressTypeTorV2:  {},
×
3515
                addressTypeTorV3:  {},
×
3516
                addressTypeOpaque: {},
×
3517
        }
×
3518
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3519
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3520
        }
×
3521

3522
        for _, address := range addresses {
×
3523
                switch addr := address.(type) {
×
3524
                case *net.TCPAddr:
×
3525
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3526
                                addAddr(addressTypeIPv4, addr)
×
3527
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3528
                                addAddr(addressTypeIPv6, addr)
×
3529
                        } else {
×
3530
                                return fmt.Errorf("unhandled IP address: %v",
×
3531
                                        addr)
×
3532
                        }
×
3533

3534
                case *tor.OnionAddr:
×
3535
                        switch len(addr.OnionService) {
×
3536
                        case tor.V2Len:
×
3537
                                addAddr(addressTypeTorV2, addr)
×
3538
                        case tor.V3Len:
×
3539
                                addAddr(addressTypeTorV3, addr)
×
3540
                        default:
×
3541
                                return fmt.Errorf("invalid length for a tor " +
×
3542
                                        "address")
×
3543
                        }
3544

3545
                case *lnwire.OpaqueAddrs:
×
3546
                        addAddr(addressTypeOpaque, addr)
×
3547

3548
                default:
×
3549
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3550
                }
3551
        }
3552

3553
        // Any remaining entries in newAddresses are new addresses that need to
3554
        // be added to the database for the first time.
3555
        for addrType, addrList := range newAddresses {
×
3556
                for position, addr := range addrList {
×
3557
                        err := db.InsertNodeAddress(
×
3558
                                ctx, sqlc.InsertNodeAddressParams{
×
3559
                                        NodeID:   nodeID,
×
3560
                                        Type:     int16(addrType),
×
3561
                                        Address:  addr,
×
3562
                                        Position: int32(position),
×
3563
                                },
×
3564
                        )
×
3565
                        if err != nil {
×
3566
                                return fmt.Errorf("unable to insert "+
×
3567
                                        "node(%d) address(%v): %w", nodeID,
×
3568
                                        addr, err)
×
3569
                        }
×
3570
                }
3571
        }
3572

3573
        return nil
×
3574
}
3575

3576
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3577
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3578
        error) {
×
3579

×
3580
        // GetNodeAddresses ensures that the addresses for a given type are
×
3581
        // returned in the same order as they were inserted.
×
3582
        rows, err := db.GetNodeAddresses(ctx, id)
×
3583
        if err != nil {
×
3584
                return nil, err
×
3585
        }
×
3586

3587
        addresses := make([]net.Addr, 0, len(rows))
×
3588
        for _, row := range rows {
×
3589
                address := row.Address
×
3590

×
3591
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3592
                if err != nil {
×
3593
                        return nil, fmt.Errorf("unable to parse address "+
×
3594
                                "for node(%d): %v: %w", id, address, err)
×
3595
                }
×
3596

3597
                addresses = append(addresses, addr)
×
3598
        }
3599

3600
        // If we have no addresses, then we'll return nil instead of an
3601
        // empty slice.
3602
        if len(addresses) == 0 {
×
3603
                addresses = nil
×
3604
        }
×
3605

3606
        return addresses, nil
×
3607
}
3608

3609
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3610
// database. This includes updating any existing types, inserting any new types,
3611
// and deleting any types that are no longer present.
3612
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3613
        nodeID int64, extraFields map[uint64][]byte) error {
×
3614

×
3615
        // Get any existing extra signed fields for the node.
×
3616
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3617
        if err != nil {
×
3618
                return err
×
3619
        }
×
3620

3621
        // Make a lookup map of the existing field types so that we can use it
3622
        // to keep track of any fields we should delete.
3623
        m := make(map[uint64]bool)
×
3624
        for _, field := range existingFields {
×
3625
                m[uint64(field.Type)] = true
×
3626
        }
×
3627

3628
        // For all the new fields, we'll upsert them and remove them from the
3629
        // map of existing fields.
3630
        for tlvType, value := range extraFields {
×
3631
                err = db.UpsertNodeExtraType(
×
3632
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3633
                                NodeID: nodeID,
×
3634
                                Type:   int64(tlvType),
×
3635
                                Value:  value,
×
3636
                        },
×
3637
                )
×
3638
                if err != nil {
×
3639
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3640
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3641
                }
×
3642

3643
                // Remove the field from the map of existing fields if it was
3644
                // present.
3645
                delete(m, tlvType)
×
3646
        }
3647

3648
        // For all the fields that are left in the map of existing fields, we'll
3649
        // delete them as they are no longer present in the new set of fields.
3650
        for tlvType := range m {
×
3651
                err = db.DeleteExtraNodeType(
×
3652
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3653
                                NodeID: nodeID,
×
3654
                                Type:   int64(tlvType),
×
3655
                        },
×
3656
                )
×
3657
                if err != nil {
×
3658
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3659
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3660
                }
×
3661
        }
3662

3663
        return nil
×
3664
}
3665

3666
// srcNodeInfo holds the information about the source node of the graph.
3667
type srcNodeInfo struct {
3668
        // id is the DB level ID of the source node entry in the "nodes" table.
3669
        id int64
3670

3671
        // pub is the public key of the source node.
3672
        pub route.Vertex
3673
}
3674

3675
// sourceNode returns the DB node ID and pub key of the source node for the
3676
// specified protocol version.
3677
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3678
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3679

×
3680
        s.srcNodeMu.Lock()
×
3681
        defer s.srcNodeMu.Unlock()
×
3682

×
3683
        // If we already have the source node ID and pub key cached, then
×
3684
        // return them.
×
3685
        if info, ok := s.srcNodes[version]; ok {
×
3686
                return info.id, info.pub, nil
×
3687
        }
×
3688

3689
        var pubKey route.Vertex
×
3690

×
3691
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3692
        if err != nil {
×
3693
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3694
                        err)
×
3695
        }
×
3696

3697
        if len(nodes) == 0 {
×
3698
                return 0, pubKey, ErrSourceNodeNotSet
×
3699
        } else if len(nodes) > 1 {
×
3700
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3701
                        "protocol %s found", version)
×
3702
        }
×
3703

3704
        copy(pubKey[:], nodes[0].PubKey)
×
3705

×
3706
        s.srcNodes[version] = &srcNodeInfo{
×
3707
                id:  nodes[0].NodeID,
×
3708
                pub: pubKey,
×
3709
        }
×
3710

×
3711
        return nodes[0].NodeID, pubKey, nil
×
3712
}
3713

3714
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3715
// This then produces a map from TLV type to value. If the input is not a
3716
// valid TLV stream, then an error is returned.
3717
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3718
        r := bytes.NewReader(data)
×
3719

×
3720
        tlvStream, err := tlv.NewStream()
×
3721
        if err != nil {
×
3722
                return nil, err
×
3723
        }
×
3724

3725
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3726
        // pass it into the P2P decoding variant.
3727
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3728
        if err != nil {
×
3729
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3730
        }
×
3731
        if len(parsedTypes) == 0 {
×
3732
                return nil, nil
×
3733
        }
×
3734

3735
        records := make(map[uint64][]byte)
×
3736
        for k, v := range parsedTypes {
×
3737
                records[uint64(k)] = v
×
3738
        }
×
3739

3740
        return records, nil
×
3741
}
3742

3743
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3744
// channel.
3745
type dbChanInfo struct {
3746
        channelID int64
3747
        node1ID   int64
3748
        node2ID   int64
3749
}
3750

3751
// insertChannel inserts a new channel record into the database.
3752
func insertChannel(ctx context.Context, db SQLQueries,
3753
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3754

×
3755
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3756

×
3757
        // Make sure that the channel doesn't already exist. We do this
×
3758
        // explicitly instead of relying on catching a unique constraint error
×
3759
        // because relying on SQL to throw that error would abort the entire
×
3760
        // batch of transactions.
×
3761
        _, err := db.GetChannelBySCID(
×
3762
                ctx, sqlc.GetChannelBySCIDParams{
×
3763
                        Scid:    chanIDB,
×
3764
                        Version: int16(ProtocolV1),
×
3765
                },
×
3766
        )
×
3767
        if err == nil {
×
3768
                return nil, ErrEdgeAlreadyExist
×
3769
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3770
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3771
        }
×
3772

3773
        // Make sure that at least a "shell" entry for each node is present in
3774
        // the nodes table.
3775
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3776
        if err != nil {
×
3777
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3778
        }
×
3779

3780
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3781
        if err != nil {
×
3782
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3783
        }
×
3784

3785
        var capacity sql.NullInt64
×
3786
        if edge.Capacity != 0 {
×
3787
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3788
        }
×
3789

3790
        createParams := sqlc.CreateChannelParams{
×
3791
                Version:     int16(ProtocolV1),
×
3792
                Scid:        chanIDB,
×
3793
                NodeID1:     node1DBID,
×
3794
                NodeID2:     node2DBID,
×
3795
                Outpoint:    edge.ChannelPoint.String(),
×
3796
                Capacity:    capacity,
×
3797
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3798
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3799
        }
×
3800

×
3801
        if edge.AuthProof != nil {
×
3802
                proof := edge.AuthProof
×
3803

×
3804
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3805
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3806
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3807
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3808
        }
×
3809

3810
        // Insert the new channel record.
3811
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3812
        if err != nil {
×
3813
                return nil, err
×
3814
        }
×
3815

3816
        // Insert any channel features.
3817
        for feature := range edge.Features.Features() {
×
3818
                err = db.InsertChannelFeature(
×
3819
                        ctx, sqlc.InsertChannelFeatureParams{
×
3820
                                ChannelID:  dbChanID,
×
3821
                                FeatureBit: int32(feature),
×
3822
                        },
×
3823
                )
×
3824
                if err != nil {
×
3825
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3826
                                "feature(%v): %w", dbChanID, feature, err)
×
3827
                }
×
3828
        }
3829

3830
        // Finally, insert any extra TLV fields in the channel announcement.
3831
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3832
        if err != nil {
×
3833
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3834
                        "data: %w", err)
×
3835
        }
×
3836

3837
        for tlvType, value := range extra {
×
3838
                err := db.CreateChannelExtraType(
×
3839
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3840
                                ChannelID: dbChanID,
×
3841
                                Type:      int64(tlvType),
×
3842
                                Value:     value,
×
3843
                        },
×
3844
                )
×
3845
                if err != nil {
×
3846
                        return nil, fmt.Errorf("unable to upsert "+
×
3847
                                "channel(%d) extra signed field(%v): %w",
×
3848
                                edge.ChannelID, tlvType, err)
×
3849
                }
×
3850
        }
3851

3852
        return &dbChanInfo{
×
3853
                channelID: dbChanID,
×
3854
                node1ID:   node1DBID,
×
3855
                node2ID:   node2DBID,
×
3856
        }, nil
×
3857
}
3858

3859
// maybeCreateShellNode checks if a shell node entry exists for the
3860
// given public key. If it does not exist, then a new shell node entry is
3861
// created. The ID of the node is returned. A shell node only has a protocol
3862
// version and public key persisted.
3863
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3864
        pubKey route.Vertex) (int64, error) {
×
3865

×
3866
        dbNode, err := db.GetNodeByPubKey(
×
3867
                ctx, sqlc.GetNodeByPubKeyParams{
×
3868
                        PubKey:  pubKey[:],
×
3869
                        Version: int16(ProtocolV1),
×
3870
                },
×
3871
        )
×
3872
        // The node exists. Return the ID.
×
3873
        if err == nil {
×
3874
                return dbNode.ID, nil
×
3875
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3876
                return 0, err
×
3877
        }
×
3878

3879
        // Otherwise, the node does not exist, so we create a shell entry for
3880
        // it.
3881
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3882
                Version: int16(ProtocolV1),
×
3883
                PubKey:  pubKey[:],
×
3884
        })
×
3885
        if err != nil {
×
3886
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3887
        }
×
3888

3889
        return id, nil
×
3890
}
3891

3892
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3893
// the database. This includes deleting any existing types and then inserting
3894
// the new types.
3895
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3896
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3897

×
3898
        // Delete all existing extra signed fields for the channel policy.
×
3899
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3900
        if err != nil {
×
3901
                return fmt.Errorf("unable to delete "+
×
3902
                        "existing policy extra signed fields for policy %d: %w",
×
3903
                        chanPolicyID, err)
×
3904
        }
×
3905

3906
        // Insert all new extra signed fields for the channel policy.
3907
        for tlvType, value := range extraFields {
×
3908
                err = db.InsertChanPolicyExtraType(
×
3909
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3910
                                ChannelPolicyID: chanPolicyID,
×
3911
                                Type:            int64(tlvType),
×
3912
                                Value:           value,
×
3913
                        },
×
3914
                )
×
3915
                if err != nil {
×
3916
                        return fmt.Errorf("unable to insert "+
×
3917
                                "channel_policy(%d) extra signed field(%v): %w",
×
3918
                                chanPolicyID, tlvType, err)
×
3919
                }
×
3920
        }
3921

3922
        return nil
×
3923
}
3924

3925
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3926
// provided dbChanRow and also fetches any other required information
3927
// to construct the edge info.
3928
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3929
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
3930
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3931

×
3932
        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
×
3933
        // edge, and so no paged queries will be performed. This means that
×
3934
        // it's ok to used pass in default config values here.
×
NEW
3935
        cfg := sqldb.DefaultQueryConfig()
×
3936

×
3937
        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
×
3938
        if err != nil {
×
3939
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3940
                        err)
×
3941
        }
×
3942

3943
        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
×
3944
}
3945

3946
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3947
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3948
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3949
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3950

×
3951
        if dbChan.Version != int16(ProtocolV1) {
×
3952
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3953
                        dbChan.Version)
×
3954
        }
×
3955

3956
        // Use pre-loaded features and extras types.
3957
        fv := lnwire.EmptyFeatureVector()
×
3958
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3959
                for _, bit := range features {
×
3960
                        fv.Set(lnwire.FeatureBit(bit))
×
3961
                }
×
3962
        }
3963

3964
        var extras map[uint64][]byte
×
3965
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3966
        if exists {
×
3967
                extras = channelExtras
×
3968
        } else {
×
3969
                extras = make(map[uint64][]byte)
×
3970
        }
×
3971

3972
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3973
        if err != nil {
×
3974
                return nil, err
×
3975
        }
×
3976

3977
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3978
        if err != nil {
×
3979
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3980
                        "fields: %w", err)
×
3981
        }
×
3982
        if recs == nil {
×
3983
                recs = make([]byte, 0)
×
3984
        }
×
3985

3986
        var btcKey1, btcKey2 route.Vertex
×
3987
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3988
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3989

×
3990
        channel := &models.ChannelEdgeInfo{
×
3991
                ChainHash:        chain,
×
3992
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3993
                NodeKey1Bytes:    node1,
×
3994
                NodeKey2Bytes:    node2,
×
3995
                BitcoinKey1Bytes: btcKey1,
×
3996
                BitcoinKey2Bytes: btcKey2,
×
3997
                ChannelPoint:     *op,
×
3998
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3999
                Features:         fv,
×
4000
                ExtraOpaqueData:  recs,
×
4001
        }
×
4002

×
4003
        // We always set all the signatures at the same time, so we can
×
4004
        // safely check if one signature is present to determine if we have the
×
4005
        // rest of the signatures for the auth proof.
×
4006
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4007
                channel.AuthProof = &models.ChannelAuthProof{
×
4008
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4009
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4010
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4011
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4012
                }
×
4013
        }
×
4014

4015
        return channel, nil
×
4016
}
4017

4018
// buildNodeVertices is a helper that converts raw node public keys
4019
// into route.Vertex instances.
4020
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4021
        route.Vertex, error) {
×
4022

×
4023
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4024
        if err != nil {
×
4025
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4026
                        "create vertex from node1 pubkey: %w", err)
×
4027
        }
×
4028

4029
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4030
        if err != nil {
×
4031
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4032
                        "create vertex from node2 pubkey: %w", err)
×
4033
        }
×
4034

4035
        return node1Vertex, node2Vertex, nil
×
4036
}
4037

4038
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4039
// retrieves all the extra info required to build the complete
4040
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4041
// the provided sqlc.GraphChannelPolicy records are nil.
4042
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4043
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4044
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4045
        *models.ChannelEdgePolicy, error) {
×
4046

×
4047
        if dbPol1 == nil && dbPol2 == nil {
×
4048
                return nil, nil, nil
×
4049
        }
×
4050

4051
        var policyIDs = make([]int64, 0, 2)
×
4052
        if dbPol1 != nil {
×
4053
                policyIDs = append(policyIDs, dbPol1.ID)
×
4054
        }
×
4055
        if dbPol2 != nil {
×
4056
                policyIDs = append(policyIDs, dbPol2.ID)
×
4057
        }
×
4058

4059
        // NOTE: getAndBuildChanPolicies is only used to load the data for
4060
        // a maximum of two policies, and so no paged queries will be
4061
        // performed (unless the page size is one). So it's ok to use
4062
        // the default config values here.
NEW
4063
        cfg := sqldb.DefaultQueryConfig()
×
4064

×
4065
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4066
        if err != nil {
×
4067
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4068
                        "data: %w", err)
×
4069
        }
×
4070

4071
        pol1, err := buildChanPolicyWithBatchData(
×
4072
                dbPol1, channelID, node2, batchData,
×
4073
        )
×
4074
        if err != nil {
×
4075
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4076
        }
×
4077

4078
        pol2, err := buildChanPolicyWithBatchData(
×
4079
                dbPol2, channelID, node1, batchData,
×
4080
        )
×
4081
        if err != nil {
×
4082
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4083
        }
×
4084

4085
        return pol1, pol2, nil
×
4086
}
4087

4088
// buildChanPolicyWithBatchData models.CachedEdgePolicy instances from the
4089
// provided sqlc.GraphChannelPolicy objects. If the provided policy is nil,
4090
// then nil is returned.
4091
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4092
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
NEW
4093
        *models.CachedEdgePolicy, error) {
×
NEW
4094

×
NEW
4095
        var p1, p2 *models.CachedEdgePolicy
×
NEW
4096
        if dbPol1 != nil {
×
NEW
4097
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
NEW
4098
                if err != nil {
×
NEW
4099
                        return nil, nil, err
×
NEW
4100
                }
×
4101

NEW
4102
                p1 = models.NewCachedPolicy(policy1)
×
4103
        }
NEW
4104
        if dbPol2 != nil {
×
NEW
4105
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
NEW
4106
                if err != nil {
×
NEW
4107
                        return nil, nil, err
×
NEW
4108
                }
×
4109

NEW
4110
                p2 = models.NewCachedPolicy(policy2)
×
4111
        }
4112

NEW
4113
        return p1, p2, nil
×
4114
}
4115

4116
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4117
// provided sqlc.GraphChannelPolicy and other required information.
4118
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4119
        extras map[uint64][]byte,
4120
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4121

×
4122
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4123
        if err != nil {
×
4124
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4125
                        "fields: %w", err)
×
4126
        }
×
4127

4128
        var inboundFee fn.Option[lnwire.Fee]
×
4129
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4130
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4131

×
4132
                inboundFee = fn.Some(lnwire.Fee{
×
4133
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4134
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4135
                })
×
4136
        }
×
4137

4138
        return &models.ChannelEdgePolicy{
×
4139
                SigBytes:  dbPolicy.Signature,
×
4140
                ChannelID: channelID,
×
4141
                LastUpdate: time.Unix(
×
4142
                        dbPolicy.LastUpdate.Int64, 0,
×
4143
                ),
×
4144
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4145
                        dbPolicy.MessageFlags,
×
4146
                ),
×
4147
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4148
                        dbPolicy.ChannelFlags,
×
4149
                ),
×
4150
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4151
                MinHTLC: lnwire.MilliSatoshi(
×
4152
                        dbPolicy.MinHtlcMsat,
×
4153
                ),
×
4154
                MaxHTLC: lnwire.MilliSatoshi(
×
4155
                        dbPolicy.MaxHtlcMsat.Int64,
×
4156
                ),
×
4157
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4158
                        dbPolicy.BaseFeeMsat,
×
4159
                ),
×
4160
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4161
                ToNode:                    toNode,
×
4162
                InboundFee:                inboundFee,
×
4163
                ExtraOpaqueData:           recs,
×
4164
        }, nil
×
4165
}
4166

4167
// buildNodes builds the models.LightningNode instances for the
4168
// given row which is expected to be a sqlc type that contains node information.
4169
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4170
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4171
        error) {
×
4172

×
4173
        node1, err := buildNode(ctx, db, &dbNode1)
×
4174
        if err != nil {
×
4175
                return nil, nil, err
×
4176
        }
×
4177

4178
        node2, err := buildNode(ctx, db, &dbNode2)
×
4179
        if err != nil {
×
4180
                return nil, nil, err
×
4181
        }
×
4182

4183
        return node1, node2, nil
×
4184
}
4185

4186
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4187
// row which is expected to be a sqlc type that contains channel policy
4188
// information. It returns two policies, which may be nil if the policy
4189
// information is not present in the row.
4190
//
4191
//nolint:ll,dupl,funlen
4192
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4193
        *sqlc.GraphChannelPolicy, error) {
×
4194

×
4195
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4196
        switch r := row.(type) {
×
4197
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4198
                if r.Policy1Timelock.Valid {
×
4199
                        policy1 = &sqlc.GraphChannelPolicy{
×
4200
                                Timelock:                r.Policy1Timelock.Int32,
×
4201
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4202
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4203
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4204
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4205
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4206
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4207
                                Disabled:                r.Policy1Disabled,
×
4208
                                MessageFlags:            r.Policy1MessageFlags,
×
4209
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4210
                        }
×
4211
                }
×
4212
                if r.Policy2Timelock.Valid {
×
4213
                        policy2 = &sqlc.GraphChannelPolicy{
×
4214
                                Timelock:                r.Policy2Timelock.Int32,
×
4215
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4216
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4217
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4218
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4219
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4220
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4221
                                Disabled:                r.Policy2Disabled,
×
4222
                                MessageFlags:            r.Policy2MessageFlags,
×
4223
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4224
                        }
×
4225
                }
×
4226

4227
                return policy1, policy2, nil
×
4228

4229
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4230
                if r.Policy1ID.Valid {
×
4231
                        policy1 = &sqlc.GraphChannelPolicy{
×
4232
                                ID:                      r.Policy1ID.Int64,
×
4233
                                Version:                 r.Policy1Version.Int16,
×
4234
                                ChannelID:               r.GraphChannel.ID,
×
4235
                                NodeID:                  r.Policy1NodeID.Int64,
×
4236
                                Timelock:                r.Policy1Timelock.Int32,
×
4237
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4238
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4239
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4240
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4241
                                LastUpdate:              r.Policy1LastUpdate,
×
4242
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4243
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4244
                                Disabled:                r.Policy1Disabled,
×
4245
                                MessageFlags:            r.Policy1MessageFlags,
×
4246
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4247
                                Signature:               r.Policy1Signature,
×
4248
                        }
×
4249
                }
×
4250
                if r.Policy2ID.Valid {
×
4251
                        policy2 = &sqlc.GraphChannelPolicy{
×
4252
                                ID:                      r.Policy2ID.Int64,
×
4253
                                Version:                 r.Policy2Version.Int16,
×
4254
                                ChannelID:               r.GraphChannel.ID,
×
4255
                                NodeID:                  r.Policy2NodeID.Int64,
×
4256
                                Timelock:                r.Policy2Timelock.Int32,
×
4257
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4258
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4259
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4260
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4261
                                LastUpdate:              r.Policy2LastUpdate,
×
4262
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4263
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4264
                                Disabled:                r.Policy2Disabled,
×
4265
                                MessageFlags:            r.Policy2MessageFlags,
×
4266
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4267
                                Signature:               r.Policy2Signature,
×
4268
                        }
×
4269
                }
×
4270

4271
                return policy1, policy2, nil
×
4272

4273
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4274
                if r.Policy1ID.Valid {
×
4275
                        policy1 = &sqlc.GraphChannelPolicy{
×
4276
                                ID:                      r.Policy1ID.Int64,
×
4277
                                Version:                 r.Policy1Version.Int16,
×
4278
                                ChannelID:               r.GraphChannel.ID,
×
4279
                                NodeID:                  r.Policy1NodeID.Int64,
×
4280
                                Timelock:                r.Policy1Timelock.Int32,
×
4281
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4282
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4283
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4284
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4285
                                LastUpdate:              r.Policy1LastUpdate,
×
4286
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4287
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4288
                                Disabled:                r.Policy1Disabled,
×
4289
                                MessageFlags:            r.Policy1MessageFlags,
×
4290
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4291
                                Signature:               r.Policy1Signature,
×
4292
                        }
×
4293
                }
×
4294
                if r.Policy2ID.Valid {
×
4295
                        policy2 = &sqlc.GraphChannelPolicy{
×
4296
                                ID:                      r.Policy2ID.Int64,
×
4297
                                Version:                 r.Policy2Version.Int16,
×
4298
                                ChannelID:               r.GraphChannel.ID,
×
4299
                                NodeID:                  r.Policy2NodeID.Int64,
×
4300
                                Timelock:                r.Policy2Timelock.Int32,
×
4301
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4302
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4303
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4304
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4305
                                LastUpdate:              r.Policy2LastUpdate,
×
4306
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4307
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4308
                                Disabled:                r.Policy2Disabled,
×
4309
                                MessageFlags:            r.Policy2MessageFlags,
×
4310
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4311
                                Signature:               r.Policy2Signature,
×
4312
                        }
×
4313
                }
×
4314

4315
                return policy1, policy2, nil
×
4316

4317
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4318
                if r.Policy1ID.Valid {
×
4319
                        policy1 = &sqlc.GraphChannelPolicy{
×
4320
                                ID:                      r.Policy1ID.Int64,
×
4321
                                Version:                 r.Policy1Version.Int16,
×
4322
                                ChannelID:               r.GraphChannel.ID,
×
4323
                                NodeID:                  r.Policy1NodeID.Int64,
×
4324
                                Timelock:                r.Policy1Timelock.Int32,
×
4325
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4326
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4327
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4328
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4329
                                LastUpdate:              r.Policy1LastUpdate,
×
4330
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4331
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4332
                                Disabled:                r.Policy1Disabled,
×
4333
                                MessageFlags:            r.Policy1MessageFlags,
×
4334
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4335
                                Signature:               r.Policy1Signature,
×
4336
                        }
×
4337
                }
×
4338
                if r.Policy2ID.Valid {
×
4339
                        policy2 = &sqlc.GraphChannelPolicy{
×
4340
                                ID:                      r.Policy2ID.Int64,
×
4341
                                Version:                 r.Policy2Version.Int16,
×
4342
                                ChannelID:               r.GraphChannel.ID,
×
4343
                                NodeID:                  r.Policy2NodeID.Int64,
×
4344
                                Timelock:                r.Policy2Timelock.Int32,
×
4345
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4346
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4347
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4348
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4349
                                LastUpdate:              r.Policy2LastUpdate,
×
4350
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4351
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4352
                                Disabled:                r.Policy2Disabled,
×
4353
                                MessageFlags:            r.Policy2MessageFlags,
×
4354
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4355
                                Signature:               r.Policy2Signature,
×
4356
                        }
×
4357
                }
×
4358

4359
                return policy1, policy2, nil
×
4360

4361
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4362
                if r.Policy1ID.Valid {
×
4363
                        policy1 = &sqlc.GraphChannelPolicy{
×
4364
                                ID:                      r.Policy1ID.Int64,
×
4365
                                Version:                 r.Policy1Version.Int16,
×
4366
                                ChannelID:               r.GraphChannel.ID,
×
4367
                                NodeID:                  r.Policy1NodeID.Int64,
×
4368
                                Timelock:                r.Policy1Timelock.Int32,
×
4369
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4370
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4371
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4372
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4373
                                LastUpdate:              r.Policy1LastUpdate,
×
4374
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4375
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4376
                                Disabled:                r.Policy1Disabled,
×
4377
                                MessageFlags:            r.Policy1MessageFlags,
×
4378
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4379
                                Signature:               r.Policy1Signature,
×
4380
                        }
×
4381
                }
×
4382
                if r.Policy2ID.Valid {
×
4383
                        policy2 = &sqlc.GraphChannelPolicy{
×
4384
                                ID:                      r.Policy2ID.Int64,
×
4385
                                Version:                 r.Policy2Version.Int16,
×
4386
                                ChannelID:               r.GraphChannel.ID,
×
4387
                                NodeID:                  r.Policy2NodeID.Int64,
×
4388
                                Timelock:                r.Policy2Timelock.Int32,
×
4389
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4390
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4391
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4392
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4393
                                LastUpdate:              r.Policy2LastUpdate,
×
4394
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4395
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4396
                                Disabled:                r.Policy2Disabled,
×
4397
                                MessageFlags:            r.Policy2MessageFlags,
×
4398
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4399
                                Signature:               r.Policy2Signature,
×
4400
                        }
×
4401
                }
×
4402

4403
                return policy1, policy2, nil
×
4404

4405
        case sqlc.ListChannelsByNodeIDRow:
×
4406
                if r.Policy1ID.Valid {
×
4407
                        policy1 = &sqlc.GraphChannelPolicy{
×
4408
                                ID:                      r.Policy1ID.Int64,
×
4409
                                Version:                 r.Policy1Version.Int16,
×
4410
                                ChannelID:               r.GraphChannel.ID,
×
4411
                                NodeID:                  r.Policy1NodeID.Int64,
×
4412
                                Timelock:                r.Policy1Timelock.Int32,
×
4413
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4414
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4415
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4416
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4417
                                LastUpdate:              r.Policy1LastUpdate,
×
4418
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4419
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4420
                                Disabled:                r.Policy1Disabled,
×
4421
                                MessageFlags:            r.Policy1MessageFlags,
×
4422
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4423
                                Signature:               r.Policy1Signature,
×
4424
                        }
×
4425
                }
×
4426
                if r.Policy2ID.Valid {
×
4427
                        policy2 = &sqlc.GraphChannelPolicy{
×
4428
                                ID:                      r.Policy2ID.Int64,
×
4429
                                Version:                 r.Policy2Version.Int16,
×
4430
                                ChannelID:               r.GraphChannel.ID,
×
4431
                                NodeID:                  r.Policy2NodeID.Int64,
×
4432
                                Timelock:                r.Policy2Timelock.Int32,
×
4433
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4434
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4435
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4436
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4437
                                LastUpdate:              r.Policy2LastUpdate,
×
4438
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4439
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4440
                                Disabled:                r.Policy2Disabled,
×
4441
                                MessageFlags:            r.Policy2MessageFlags,
×
4442
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4443
                                Signature:               r.Policy2Signature,
×
4444
                        }
×
4445
                }
×
4446

4447
                return policy1, policy2, nil
×
4448

4449
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4450
                if r.Policy1ID.Valid {
×
4451
                        policy1 = &sqlc.GraphChannelPolicy{
×
4452
                                ID:                      r.Policy1ID.Int64,
×
4453
                                Version:                 r.Policy1Version.Int16,
×
4454
                                ChannelID:               r.GraphChannel.ID,
×
4455
                                NodeID:                  r.Policy1NodeID.Int64,
×
4456
                                Timelock:                r.Policy1Timelock.Int32,
×
4457
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4458
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4459
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4460
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4461
                                LastUpdate:              r.Policy1LastUpdate,
×
4462
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4463
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4464
                                Disabled:                r.Policy1Disabled,
×
4465
                                MessageFlags:            r.Policy1MessageFlags,
×
4466
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4467
                                Signature:               r.Policy1Signature,
×
4468
                        }
×
4469
                }
×
4470
                if r.Policy2ID.Valid {
×
4471
                        policy2 = &sqlc.GraphChannelPolicy{
×
4472
                                ID:                      r.Policy2ID.Int64,
×
4473
                                Version:                 r.Policy2Version.Int16,
×
4474
                                ChannelID:               r.GraphChannel.ID,
×
4475
                                NodeID:                  r.Policy2NodeID.Int64,
×
4476
                                Timelock:                r.Policy2Timelock.Int32,
×
4477
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4478
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4479
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4480
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4481
                                LastUpdate:              r.Policy2LastUpdate,
×
4482
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4483
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4484
                                Disabled:                r.Policy2Disabled,
×
4485
                                MessageFlags:            r.Policy2MessageFlags,
×
4486
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4487
                                Signature:               r.Policy2Signature,
×
4488
                        }
×
4489
                }
×
4490

4491
                return policy1, policy2, nil
×
4492
        default:
×
4493
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4494
                        "extractChannelPolicies: %T", r)
×
4495
        }
4496
}
4497

4498
// channelIDToBytes converts a channel ID (SCID) to a byte array
4499
// representation.
4500
func channelIDToBytes(channelID uint64) []byte {
×
4501
        var chanIDB [8]byte
×
4502
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4503

×
4504
        return chanIDB[:]
×
4505
}
×
4506

4507
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4508
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4509
        if len(addresses) == 0 {
×
4510
                return nil, nil
×
4511
        }
×
4512

4513
        result := make([]net.Addr, 0, len(addresses))
×
4514
        for _, addr := range addresses {
×
4515
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4516
                if err != nil {
×
4517
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4518
                                "of type %d: %w", addr.address, addr.addrType,
×
4519
                                err)
×
4520
                }
×
4521
                if netAddr != nil {
×
4522
                        result = append(result, netAddr)
×
4523
                }
×
4524
        }
4525

4526
        // If we have no valid addresses, return nil instead of empty slice.
4527
        if len(result) == 0 {
×
4528
                return nil, nil
×
4529
        }
×
4530

4531
        return result, nil
×
4532
}
4533

4534
// parseAddress parses the given address string based on the address type
4535
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4536
// and opaque addresses.
4537
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4538
        switch addrType {
×
4539
        case addressTypeIPv4:
×
4540
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4541
                if err != nil {
×
4542
                        return nil, err
×
4543
                }
×
4544

4545
                tcp.IP = tcp.IP.To4()
×
4546

×
4547
                return tcp, nil
×
4548

4549
        case addressTypeIPv6:
×
4550
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4551
                if err != nil {
×
4552
                        return nil, err
×
4553
                }
×
4554

4555
                return tcp, nil
×
4556

4557
        case addressTypeTorV3, addressTypeTorV2:
×
4558
                service, portStr, err := net.SplitHostPort(address)
×
4559
                if err != nil {
×
4560
                        return nil, fmt.Errorf("unable to split tor "+
×
4561
                                "address: %v", address)
×
4562
                }
×
4563

4564
                port, err := strconv.Atoi(portStr)
×
4565
                if err != nil {
×
4566
                        return nil, err
×
4567
                }
×
4568

4569
                return &tor.OnionAddr{
×
4570
                        OnionService: service,
×
4571
                        Port:         port,
×
4572
                }, nil
×
4573

4574
        case addressTypeOpaque:
×
4575
                opaque, err := hex.DecodeString(address)
×
4576
                if err != nil {
×
4577
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4578
                                "address: %v", address)
×
4579
                }
×
4580

4581
                return &lnwire.OpaqueAddrs{
×
4582
                        Payload: opaque,
×
4583
                }, nil
×
4584

4585
        default:
×
4586
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4587
        }
4588
}
4589

4590
// batchNodeData holds all the related data for a batch of nodes.
4591
type batchNodeData struct {
4592
        // features is a map from a DB node ID to the feature bits for that
4593
        // node.
4594
        features map[int64][]int
4595

4596
        // addresses is a map from a DB node ID to the node's addresses.
4597
        addresses map[int64][]nodeAddress
4598

4599
        // extraFields is a map from a DB node ID to the extra signed fields
4600
        // for that node.
4601
        extraFields map[int64]map[uint64][]byte
4602
}
4603

4604
// nodeAddress holds the address type, position and address string for a
4605
// node. This is used to batch the fetching of node addresses.
4606
type nodeAddress struct {
4607
        addrType dbAddressType
4608
        position int32
4609
        address  string
4610
}
4611

4612
// batchLoadNodeData loads all related data for a batch of node IDs using the
4613
// provided SQLQueries interface. It returns a batchNodeData instance containing
4614
// the node features, addresses and extra signed fields.
4615
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4616
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4617

×
4618
        // Batch load the node features.
×
4619
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4620
        if err != nil {
×
4621
                return nil, fmt.Errorf("unable to batch load node "+
×
4622
                        "features: %w", err)
×
4623
        }
×
4624

4625
        // Batch load the node addresses.
4626
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4627
        if err != nil {
×
4628
                return nil, fmt.Errorf("unable to batch load node "+
×
4629
                        "addresses: %w", err)
×
4630
        }
×
4631

4632
        // Batch load the node extra signed fields.
4633
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4634
        if err != nil {
×
4635
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4636
                        "signed fields: %w", err)
×
4637
        }
×
4638

4639
        return &batchNodeData{
×
4640
                features:    features,
×
4641
                addresses:   addrs,
×
4642
                extraFields: extraTypes,
×
4643
        }, nil
×
4644
}
4645

4646
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4647
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4648
func batchLoadNodeFeaturesHelper(ctx context.Context,
4649
        cfg *sqldb.QueryConfig, db SQLQueries,
4650
        nodeIDs []int64) (map[int64][]int, error) {
×
4651

×
4652
        features := make(map[int64][]int)
×
4653

×
NEW
4654
        return features, sqldb.ExecuteBatchQuery(
×
4655
                ctx, cfg, nodeIDs,
×
4656
                func(id int64) int64 {
×
4657
                        return id
×
4658
                },
×
4659
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4660
                        error) {
×
4661

×
4662
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4663
                },
×
4664
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4665
                        features[feature.NodeID] = append(
×
4666
                                features[feature.NodeID],
×
4667
                                int(feature.FeatureBit),
×
4668
                        )
×
4669

×
4670
                        return nil
×
4671
                },
×
4672
        )
4673
}
4674

4675
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
4676
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4677
// node ID to a slice of nodeAddress structs.
4678
func batchLoadNodeAddressesHelper(ctx context.Context,
4679
        cfg *sqldb.QueryConfig, db SQLQueries,
4680
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4681

×
4682
        addrs := make(map[int64][]nodeAddress)
×
4683

×
NEW
4684
        return addrs, sqldb.ExecuteBatchQuery(
×
4685
                ctx, cfg, nodeIDs,
×
4686
                func(id int64) int64 {
×
4687
                        return id
×
4688
                },
×
4689
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4690
                        error) {
×
4691

×
4692
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4693
                },
×
4694
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4695
                        addrs[addr.NodeID] = append(
×
4696
                                addrs[addr.NodeID], nodeAddress{
×
4697
                                        addrType: dbAddressType(addr.Type),
×
4698
                                        position: addr.Position,
×
4699
                                        address:  addr.Address,
×
4700
                                },
×
4701
                        )
×
4702

×
4703
                        return nil
×
4704
                },
×
4705
        )
4706
}
4707

4708
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4709
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
4710
// query.
4711
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4712
        cfg *sqldb.QueryConfig, db SQLQueries,
4713
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4714

×
4715
        extraFields := make(map[int64]map[uint64][]byte)
×
4716

×
4717
        callback := func(ctx context.Context,
×
4718
                field sqlc.GraphNodeExtraType) error {
×
4719

×
4720
                if extraFields[field.NodeID] == nil {
×
4721
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4722
                }
×
4723
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4724

×
4725
                return nil
×
4726
        }
4727

NEW
4728
        return extraFields, sqldb.ExecuteBatchQuery(
×
4729
                ctx, cfg, nodeIDs,
×
4730
                func(id int64) int64 {
×
4731
                        return id
×
4732
                },
×
4733
                func(ctx context.Context, ids []int64) (
4734
                        []sqlc.GraphNodeExtraType, error) {
×
4735

×
4736
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4737
                },
×
4738
                callback,
4739
        )
4740
}
4741

4742
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
4743
// from the provided sqlc.GraphChannelPolicy records and the
4744
// provided batchChannelData.
4745
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4746
        channelID uint64, node1, node2 route.Vertex,
4747
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
4748
        *models.ChannelEdgePolicy, error) {
×
4749

×
4750
        pol1, err := buildChanPolicyWithBatchData(
×
4751
                dbPol1, channelID, node2, batchData,
×
4752
        )
×
4753
        if err != nil {
×
4754
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4755
        }
×
4756

4757
        pol2, err := buildChanPolicyWithBatchData(
×
4758
                dbPol2, channelID, node1, batchData,
×
4759
        )
×
4760
        if err != nil {
×
4761
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4762
        }
×
4763

4764
        return pol1, pol2, nil
×
4765
}
4766

4767
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
4768
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
4769
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
4770
        channelID uint64, toNode route.Vertex,
4771
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
4772

×
4773
        if dbPol == nil {
×
4774
                return nil, nil
×
4775
        }
×
4776

4777
        var dbPol1Extras map[uint64][]byte
×
4778
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
4779
                dbPol1Extras = extras
×
4780
        } else {
×
4781
                dbPol1Extras = make(map[uint64][]byte)
×
4782
        }
×
4783

4784
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
4785
}
4786

4787
// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}
4800

4801
// batchLoadChannelData loads all related data for batches of channels and
4802
// policies.
4803
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
4804
        db SQLQueries, channelIDs []int64,
4805
        policyIDs []int64) (*batchChannelData, error) {
×
4806

×
4807
        batchData := &batchChannelData{
×
4808
                chanfeatures:   make(map[int64][]int),
×
4809
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
4810
                policyExtras:   make(map[int64]map[uint64][]byte),
×
4811
        }
×
4812

×
4813
        // Batch load channel features and extras
×
4814
        var err error
×
4815
        if len(channelIDs) > 0 {
×
4816
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
4817
                        ctx, cfg, db, channelIDs,
×
4818
                )
×
4819
                if err != nil {
×
4820
                        return nil, fmt.Errorf("unable to batch load "+
×
4821
                                "channel features: %w", err)
×
4822
                }
×
4823

4824
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
4825
                        ctx, cfg, db, channelIDs,
×
4826
                )
×
4827
                if err != nil {
×
4828
                        return nil, fmt.Errorf("unable to batch load "+
×
4829
                                "channel extras: %w", err)
×
4830
                }
×
4831
        }
4832

4833
        if len(policyIDs) > 0 {
×
4834
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
4835
                        ctx, cfg, db, policyIDs,
×
4836
                )
×
4837
                if err != nil {
×
4838
                        return nil, fmt.Errorf("unable to batch load "+
×
4839
                                "policy extras: %w", err)
×
4840
                }
×
4841
                batchData.policyExtras = policyExtras
×
4842
        }
4843

4844
        return batchData, nil
×
4845
}
4846

4847
// batchLoadChannelFeaturesHelper loads channel features for a batch of
4848
// channel IDs using ExecuteBatchQuery wrapper around the
4849
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
4850
// slice of feature bits.
4851
func batchLoadChannelFeaturesHelper(ctx context.Context,
4852
        cfg *sqldb.QueryConfig, db SQLQueries,
4853
        channelIDs []int64) (map[int64][]int, error) {
×
4854

×
4855
        features := make(map[int64][]int)
×
4856

×
NEW
4857
        return features, sqldb.ExecuteBatchQuery(
×
4858
                ctx, cfg, channelIDs,
×
4859
                func(id int64) int64 {
×
4860
                        return id
×
4861
                },
×
4862
                func(ctx context.Context,
4863
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
4864

×
4865
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
4866
                },
×
4867
                func(ctx context.Context,
4868
                        feature sqlc.GraphChannelFeature) error {
×
4869

×
4870
                        features[feature.ChannelID] = append(
×
4871
                                features[feature.ChannelID],
×
4872
                                int(feature.FeatureBit),
×
4873
                        )
×
4874

×
4875
                        return nil
×
4876
                },
×
4877
        )
4878
}
4879

4880
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
4881
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
4882
// query. It returns a map from DB channel ID to a map of TLV type to extra
4883
// signed field bytes.
4884
func batchLoadChannelExtrasHelper(ctx context.Context,
4885
        cfg *sqldb.QueryConfig, db SQLQueries,
4886
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4887

×
4888
        extras := make(map[int64]map[uint64][]byte)
×
4889

×
4890
        cb := func(ctx context.Context,
×
4891
                extra sqlc.GraphChannelExtraType) error {
×
4892

×
4893
                if extras[extra.ChannelID] == nil {
×
4894
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
4895
                }
×
4896
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
4897

×
4898
                return nil
×
4899
        }
4900

NEW
4901
        return extras, sqldb.ExecuteBatchQuery(
×
4902
                ctx, cfg, channelIDs,
×
4903
                func(id int64) int64 {
×
4904
                        return id
×
4905
                },
×
4906
                func(ctx context.Context,
4907
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
4908

×
4909
                        return db.GetChannelExtrasBatch(ctx, ids)
×
4910
                }, cb,
×
4911
        )
4912
}
4913

4914
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
4915
// batch of policy IDs using ExecuteBatchQuery wrapper around the
4916
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
4917
// a map of TLV type to extra signed field bytes.
4918
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
4919
        cfg *sqldb.QueryConfig, db SQLQueries,
4920
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4921

×
4922
        extras := make(map[int64]map[uint64][]byte)
×
4923

×
NEW
4924
        return extras, sqldb.ExecuteBatchQuery(
×
4925
                ctx, cfg, policyIDs,
×
4926
                func(id int64) int64 {
×
4927
                        return id
×
4928
                },
×
4929
                func(ctx context.Context, ids []int64) (
4930
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
4931

×
4932
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
4933
                },
×
4934
                func(ctx context.Context,
4935
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
4936

×
4937
                        if extras[row.PolicyID] == nil {
×
4938
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
4939
                        }
×
4940
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
4941

×
4942
                        return nil
×
4943
                },
4944
        )
4945
}
4946

4947
// forEachNodePaginated executes a paginated query to process each node in the
4948
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
4949
// and applies the provided processNode function to each node
4950
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
4951
        db SQLQueries, processNode func(context.Context, int64,
NEW
4952
                *models.LightningNode) error) error {
×
NEW
4953

×
NEW
4954
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
NEW
4955
                limit int32) ([]sqlc.GraphNode, error) {
×
NEW
4956

×
NEW
4957
                return db.ListNodesPaginated(
×
NEW
4958
                        ctx, sqlc.ListNodesPaginatedParams{
×
NEW
4959
                                Version: int16(ProtocolV1),
×
NEW
4960
                                ID:      lastID,
×
NEW
4961
                                Limit:   limit,
×
NEW
4962
                        },
×
NEW
4963
                )
×
NEW
4964
        }
×
4965

NEW
4966
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
NEW
4967
                return node.ID
×
NEW
4968
        }
×
4969

NEW
4970
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
NEW
4971
                return node.ID, nil
×
NEW
4972
        }
×
4973

NEW
4974
        batchQueryFunc := func(ctx context.Context,
×
NEW
4975
                nodeIDs []int64) (*batchNodeData, error) {
×
NEW
4976

×
NEW
4977
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
NEW
4978
        }
×
4979

NEW
4980
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
NEW
4981
                batchData *batchNodeData) error {
×
NEW
4982

×
NEW
4983
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
NEW
4984
                if err != nil {
×
NEW
4985
                        return fmt.Errorf("unable to build "+
×
NEW
4986
                                "node(id=%d): %w", dbNode.ID, err)
×
NEW
4987
                }
×
4988

NEW
4989
                return processNode(ctx, dbNode.ID, node)
×
4990
        }
4991

NEW
4992
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
NEW
4993
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
NEW
4994
                collectFunc, batchQueryFunc, processItem,
×
NEW
4995
        )
×
4996
}
4997

4998
// forEachChannelWithPolicies executes a paginated query to process each channel
4999
// with policies in the graph.
5000
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5001
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5002
                *models.ChannelEdgePolicy,
NEW
5003
                *models.ChannelEdgePolicy) error) error {
×
NEW
5004

×
NEW
5005
        type channelBatchIDs struct {
×
NEW
5006
                channelID int64
×
NEW
5007
                policyIDs []int64
×
NEW
5008
        }
×
NEW
5009

×
NEW
5010
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
NEW
5011
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
NEW
5012
                error) {
×
NEW
5013

×
NEW
5014
                return db.ListChannelsWithPoliciesPaginated(
×
NEW
5015
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
5016
                                Version: int16(ProtocolV1),
×
NEW
5017
                                ID:      lastID,
×
NEW
5018
                                Limit:   limit,
×
NEW
5019
                        },
×
NEW
5020
                )
×
NEW
5021
        }
×
5022

NEW
5023
        extractPageCursor := func(
×
NEW
5024
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
NEW
5025

×
NEW
5026
                return row.GraphChannel.ID
×
NEW
5027
        }
×
5028

NEW
5029
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
NEW
5030
                channelBatchIDs, error) {
×
NEW
5031

×
NEW
5032
                ids := channelBatchIDs{
×
NEW
5033
                        channelID: row.GraphChannel.ID,
×
NEW
5034
                }
×
NEW
5035

×
NEW
5036
                // Extract policy IDs from the row.
×
NEW
5037
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
5038
                if err != nil {
×
NEW
5039
                        return ids, err
×
NEW
5040
                }
×
5041

NEW
5042
                if dbPol1 != nil {
×
NEW
5043
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
NEW
5044
                }
×
NEW
5045
                if dbPol2 != nil {
×
NEW
5046
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
NEW
5047
                }
×
5048

NEW
5049
                return ids, nil
×
5050
        }
5051

NEW
5052
        batchDataFunc := func(ctx context.Context,
×
NEW
5053
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
NEW
5054

×
NEW
5055
                // Separate channel IDs from policy IDs.
×
NEW
5056
                var (
×
NEW
5057
                        channelIDs = make([]int64, len(allIDs))
×
NEW
5058
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
NEW
5059
                )
×
NEW
5060

×
NEW
5061
                for i, ids := range allIDs {
×
NEW
5062
                        channelIDs[i] = ids.channelID
×
NEW
5063
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
NEW
5064
                }
×
5065

NEW
5066
                return batchLoadChannelData(
×
NEW
5067
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
NEW
5068
                )
×
5069
        }
5070

NEW
5071
        processItem := func(ctx context.Context,
×
NEW
5072
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
NEW
5073
                batchData *batchChannelData) error {
×
NEW
5074

×
NEW
5075
                node1, node2, err := buildNodeVertices(
×
NEW
5076
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
5077
                )
×
NEW
5078
                if err != nil {
×
NEW
5079
                        return err
×
NEW
5080
                }
×
5081

NEW
5082
                edge, err := buildEdgeInfoWithBatchData(
×
NEW
5083
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
NEW
5084
                        batchData,
×
NEW
5085
                )
×
NEW
5086
                if err != nil {
×
NEW
5087
                        return fmt.Errorf("unable to build channel info: %w",
×
NEW
5088
                                err)
×
NEW
5089
                }
×
5090

NEW
5091
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
5092
                if err != nil {
×
NEW
5093
                        return err
×
NEW
5094
                }
×
5095

NEW
5096
                p1, p2, err := buildChanPoliciesWithBatchData(
×
NEW
5097
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
NEW
5098
                )
×
NEW
5099
                if err != nil {
×
NEW
5100
                        return err
×
NEW
5101
                }
×
5102

NEW
5103
                return processChannel(edge, p1, p2)
×
5104
        }
5105

NEW
5106
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
NEW
5107
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
NEW
5108
                collectFunc, batchDataFunc, processItem,
×
NEW
5109
        )
×
5110
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc