• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15978727834

30 Jun 2025 04:44PM UTC coverage: 57.823% (-9.8%) from 67.608%
15978727834

Pull #10010

github

web-flow
Merge f1b7ccc6b into e54206f8c
Pull Request #10010: graph/db: various misc updates

0 of 11 new or added lines in 1 file covered. (0.0%)

28388 existing lines in 458 files now uncovered.

98481 of 170315 relevant lines covered (57.82%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize is the limit for the number of records that can be returned
// in a paginated query. This can be tuned after some benchmarks.
//
// NOTE: this is an untyped constant, so it converts implicitly to whichever
// integer type a paginated query's limit parameter uses.
const pageSize = 2000
37

38
// ProtocolVersion is an enum that defines the gossip protocol version of a
39
// message.
40
type ProtocolVersion uint8
41

42
const (
43
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
44
        ProtocolV1 ProtocolVersion = 1
45
)
46

47
// String returns a string representation of the protocol version.
48
func (v ProtocolVersion) String() string {
×
49
        return fmt.Sprintf("V%d", v)
×
50
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	// cfg holds the store's static configuration, and db is the SQL
	// backend that every query and transaction is executed against.
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler batches channel writes and is constructed with
	// cacheMu as its locker (see NewSQLStore). nodeScheduler batches node
	// writes and is constructed without a locker.
	// NOTE(review): presumably the scheduler holds its locker while
	// running OnCommit, which is what makes the cache updates there safe
	// — confirm against the batch package.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// srcNodes holds per-protocol-version source node info.
	// NOTE(review): presumably guarded by srcNodeMu and populated by
	// getSourceNode — confirm.
	srcNodes  map[ProtocolVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}
174

175
// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface. The typed nil pointer costs nothing at runtime; the build fails
// if *SQLStore is missing any V1Store method.
var _ V1Store = (*SQLStore)(nil)
178

179
// SQLStoreConfig holds the configuration for the SQLStore. It is passed to
// NewSQLStore and retained by the store.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash
}
185

186
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
187
// storage backend.
188
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
189
        options ...StoreOptionModifier) (*SQLStore, error) {
×
190

×
191
        opts := DefaultOptions()
×
192
        for _, o := range options {
×
193
                o(opts)
×
194
        }
×
195

196
        if opts.NoMigration {
×
197
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
198
                        "supported for SQL stores")
×
199
        }
×
200

201
        s := &SQLStore{
×
202
                cfg:         cfg,
×
203
                db:          db,
×
204
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
205
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
206
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
207
        }
×
208

×
209
        s.chanScheduler = batch.NewTimeScheduler(
×
210
                db, &s.cacheMu, opts.BatchCommitInterval,
×
211
        )
×
212
        s.nodeScheduler = batch.NewTimeScheduler(
×
213
                db, nil, opts.BatchCommitInterval,
×
214
        )
×
215

×
216
        return s, nil
×
217
}
218

219
// AddLightningNode adds a vertex/node to the graph database. If the node is not
220
// in the database from before, this will add a new, unconnected one to the
221
// graph. If it is present from before, this will update that node's
222
// information.
223
//
224
// NOTE: part of the V1Store interface.
225
func (s *SQLStore) AddLightningNode(ctx context.Context,
226
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
227

×
228
        r := &batch.Request[SQLQueries]{
×
229
                Opts: batch.NewSchedulerOptions(opts...),
×
230
                Do: func(queries SQLQueries) error {
×
231
                        _, err := upsertNode(ctx, queries, node)
×
232
                        return err
×
233
                },
×
234
        }
235

236
        return s.nodeScheduler.Execute(ctx, r)
×
237
}
238

239
// FetchLightningNode attempts to look up a target node by its identity public
240
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
241
// returned.
242
//
243
// NOTE: part of the V1Store interface.
244
func (s *SQLStore) FetchLightningNode(ctx context.Context,
245
        pubKey route.Vertex) (*models.LightningNode, error) {
×
246

×
247
        var node *models.LightningNode
×
248
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
249
                var err error
×
250
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
251

×
252
                return err
×
253
        }, sqldb.NoOpReset)
×
254
        if err != nil {
×
255
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
256
        }
×
257

258
        return node, nil
×
259
}
260

261
// HasLightningNode determines if the graph has a vertex identified by the
262
// target node identity public key. If the node exists in the database, a
263
// timestamp of when the data for the node was lasted updated is returned along
264
// with a true boolean. Otherwise, an empty time.Time is returned with a false
265
// boolean.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) HasLightningNode(ctx context.Context,
269
        pubKey [33]byte) (time.Time, bool, error) {
×
270

×
271
        var (
×
272
                exists     bool
×
273
                lastUpdate time.Time
×
274
        )
×
275
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
276
                dbNode, err := db.GetNodeByPubKey(
×
277
                        ctx, sqlc.GetNodeByPubKeyParams{
×
278
                                Version: int16(ProtocolV1),
×
279
                                PubKey:  pubKey[:],
×
280
                        },
×
281
                )
×
282
                if errors.Is(err, sql.ErrNoRows) {
×
283
                        return nil
×
284
                } else if err != nil {
×
285
                        return fmt.Errorf("unable to fetch node: %w", err)
×
286
                }
×
287

288
                exists = true
×
289

×
290
                if dbNode.LastUpdate.Valid {
×
291
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
292
                }
×
293

294
                return nil
×
295
        }, sqldb.NoOpReset)
296
        if err != nil {
×
297
                return time.Time{}, false,
×
298
                        fmt.Errorf("unable to fetch node: %w", err)
×
299
        }
×
300

301
        return lastUpdate, exists, nil
×
302
}
303

304
// AddrsForNode returns all known addresses for the target node public key
305
// that the graph DB is aware of. The returned boolean indicates if the
306
// given node is unknown to the graph DB or not.
307
//
308
// NOTE: part of the V1Store interface.
309
func (s *SQLStore) AddrsForNode(ctx context.Context,
310
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
311

×
312
        var (
×
313
                addresses []net.Addr
×
314
                known     bool
×
315
        )
×
316
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
317
                var err error
×
318
                known, addresses, err = getNodeAddresses(
×
319
                        ctx, db, nodePub.SerializeCompressed(),
×
320
                )
×
321
                if err != nil {
×
322
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
323
                                err)
×
324
                }
×
325

326
                return nil
×
327
        }, sqldb.NoOpReset)
328
        if err != nil {
×
329
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
330
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
331
        }
×
332

333
        return known, addresses, nil
×
334
}
335

336
// DeleteLightningNode starts a new database transaction to remove a vertex/node
337
// from the database according to the node's public key.
338
//
339
// NOTE: part of the V1Store interface.
340
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
341
        pubKey route.Vertex) error {
×
342

×
343
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
344
                res, err := db.DeleteNodeByPubKey(
×
345
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
346
                                Version: int16(ProtocolV1),
×
347
                                PubKey:  pubKey[:],
×
348
                        },
×
349
                )
×
350
                if err != nil {
×
351
                        return err
×
352
                }
×
353

354
                rows, err := res.RowsAffected()
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                if rows == 0 {
×
360
                        return ErrGraphNodeNotFound
×
361
                } else if rows > 1 {
×
362
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
363
                }
×
364

365
                return err
×
366
        }, sqldb.NoOpReset)
367
        if err != nil {
×
368
                return fmt.Errorf("unable to delete node: %w", err)
×
369
        }
×
370

371
        return nil
×
372
}
373

374
// FetchNodeFeatures returns the features of the given node. If no features are
375
// known for the node, an empty feature vector is returned.
376
//
377
// NOTE: this is part of the graphdb.NodeTraverser interface.
378
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
379
        *lnwire.FeatureVector, error) {
×
380

×
381
        ctx := context.TODO()
×
382

×
383
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
384
}
×
385

386
// DisabledChannelIDs returns the channel ids of disabled channels.
387
// A channel is disabled when two of the associated ChanelEdgePolicies
388
// have their disabled bit on.
389
//
390
// NOTE: part of the V1Store interface.
391
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
392
        var (
×
393
                ctx     = context.TODO()
×
394
                chanIDs []uint64
×
395
        )
×
396
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
397
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
398
                if err != nil {
×
399
                        return fmt.Errorf("unable to fetch disabled "+
×
400
                                "channels: %w", err)
×
401
                }
×
402

403
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
404

×
405
                return nil
×
406
        }, sqldb.NoOpReset)
407
        if err != nil {
×
408
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
409
                        err)
×
410
        }
×
411

412
        return chanIDs, nil
×
413
}
414

415
// LookupAlias attempts to return the alias as advertised by the target node.
416
//
417
// NOTE: part of the V1Store interface.
418
func (s *SQLStore) LookupAlias(ctx context.Context,
419
        pub *btcec.PublicKey) (string, error) {
×
420

×
421
        var alias string
×
422
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
423
                dbNode, err := db.GetNodeByPubKey(
×
424
                        ctx, sqlc.GetNodeByPubKeyParams{
×
425
                                Version: int16(ProtocolV1),
×
426
                                PubKey:  pub.SerializeCompressed(),
×
427
                        },
×
428
                )
×
429
                if errors.Is(err, sql.ErrNoRows) {
×
430
                        return ErrNodeAliasNotFound
×
431
                } else if err != nil {
×
432
                        return fmt.Errorf("unable to fetch node: %w", err)
×
433
                }
×
434

435
                if !dbNode.Alias.Valid {
×
436
                        return ErrNodeAliasNotFound
×
437
                }
×
438

439
                alias = dbNode.Alias.String
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
445
        }
×
446

447
        return alias, nil
×
448
}
449

450
// SourceNode returns the source node of the graph. The source node is treated
451
// as the center node within a star-graph. This method may be used to kick off
452
// a path finding algorithm in order to explore the reachability of another
453
// node based off the source node.
454
//
455
// NOTE: part of the V1Store interface.
456
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
457
        error) {
×
458

×
459
        var node *models.LightningNode
×
460
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
461
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
462
                if err != nil {
×
463
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
464
                                err)
×
465
                }
×
466

467
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
468

×
469
                return err
×
470
        }, sqldb.NoOpReset)
471
        if err != nil {
×
472
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
473
        }
×
474

475
        return node, nil
×
476
}
477

478
// SetSourceNode sets the source node within the graph database. The source
479
// node is to be used as the center of a star-graph within path finding
480
// algorithms.
481
//
482
// NOTE: part of the V1Store interface.
483
func (s *SQLStore) SetSourceNode(ctx context.Context,
484
        node *models.LightningNode) error {
×
485

×
486
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
487
                id, err := upsertNode(ctx, db, node)
×
488
                if err != nil {
×
489
                        return fmt.Errorf("unable to upsert source node: %w",
×
490
                                err)
×
491
                }
×
492

493
                // Make sure that if a source node for this version is already
494
                // set, then the ID is the same as the one we are about to set.
495
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
496
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
497
                        return fmt.Errorf("unable to fetch source node: %w",
×
498
                                err)
×
499
                } else if err == nil {
×
500
                        if dbSourceNodeID != id {
×
501
                                return fmt.Errorf("v1 source node already "+
×
502
                                        "set to a different node: %d vs %d",
×
503
                                        dbSourceNodeID, id)
×
504
                        }
×
505

506
                        return nil
×
507
                }
508

509
                return db.AddSourceNode(ctx, id)
×
510
        }, sqldb.NoOpReset)
511
}
512

513
// NodeUpdatesInHorizon returns all the known lightning node which have an
514
// update timestamp within the passed range. This method can be used by two
515
// nodes to quickly determine if they have the same set of up to date node
516
// announcements.
517
//
518
// NOTE: This is part of the V1Store interface.
519
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
520
        endTime time.Time) ([]models.LightningNode, error) {
×
521

×
522
        ctx := context.TODO()
×
523

×
524
        var nodes []models.LightningNode
×
525
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
526
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
527
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
528
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
529
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
530
                        },
×
531
                )
×
532
                if err != nil {
×
533
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
534
                }
×
535

536
                for _, dbNode := range dbNodes {
×
537
                        node, err := buildNode(ctx, db, &dbNode)
×
538
                        if err != nil {
×
539
                                return fmt.Errorf("unable to build node: %w",
×
540
                                        err)
×
541
                        }
×
542

543
                        nodes = append(nodes, *node)
×
544
                }
545

546
                return nil
×
547
        }, sqldb.NoOpReset)
548
        if err != nil {
×
549
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
550
        }
×
551

552
        return nodes, nil
×
553
}
554

555
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
556
// undirected edge from the two target nodes are created. The information stored
557
// denotes the static attributes of the channel, such as the channelID, the keys
558
// involved in creation of the channel, and the set of features that the channel
559
// supports. The chanPoint and chanID are used to uniquely identify the edge
560
// globally within the database.
561
//
562
// NOTE: part of the V1Store interface.
563
func (s *SQLStore) AddChannelEdge(ctx context.Context,
564
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
565

×
566
        var alreadyExists bool
×
567
        r := &batch.Request[SQLQueries]{
×
568
                Opts: batch.NewSchedulerOptions(opts...),
×
569
                Reset: func() {
×
570
                        alreadyExists = false
×
571
                },
×
572
                Do: func(tx SQLQueries) error {
×
573
                        err := insertChannel(ctx, tx, edge)
×
574

×
575
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
576
                        // succeed, but propagate the error via local state.
×
577
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
578
                                alreadyExists = true
×
579
                                return nil
×
580
                        }
×
581

582
                        return err
×
583
                },
584
                OnCommit: func(err error) error {
×
585
                        switch {
×
586
                        case err != nil:
×
587
                                return err
×
588
                        case alreadyExists:
×
589
                                return ErrEdgeAlreadyExist
×
590
                        default:
×
591
                                s.rejectCache.remove(edge.ChannelID)
×
592
                                s.chanCache.remove(edge.ChannelID)
×
593
                                return nil
×
594
                        }
595
                },
596
        }
597

598
        return s.chanScheduler.Execute(ctx, r)
×
599
}
600

601
// HighestChanID returns the "highest" known channel ID in the channel graph.
602
// This represents the "newest" channel from the PoV of the chain. This method
603
// can be used by peers to quickly determine if their graphs are in sync.
604
//
605
// NOTE: This is part of the V1Store interface.
606
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
607
        var highestChanID uint64
×
608
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
609
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
610
                if errors.Is(err, sql.ErrNoRows) {
×
611
                        return nil
×
612
                } else if err != nil {
×
613
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
614
                                err)
×
615
                }
×
616

617
                highestChanID = byteOrder.Uint64(chanID)
×
618

×
619
                return nil
×
620
        }, sqldb.NoOpReset)
621
        if err != nil {
×
622
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
623
        }
×
624

625
        return highestChanID, nil
×
626
}
627

628
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
629
// within the database for the referenced channel. The `flags` attribute within
630
// the ChannelEdgePolicy determines which of the directed edges are being
631
// updated. If the flag is 1, then the first node's information is being
632
// updated, otherwise it's the second node's information. The node ordering is
633
// determined by the lexicographical ordering of the identity public keys of the
634
// nodes on either side of the channel.
635
//
636
// NOTE: part of the V1Store interface.
637
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
638
        edge *models.ChannelEdgePolicy,
639
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
640

×
641
        var (
×
642
                isUpdate1    bool
×
643
                edgeNotFound bool
×
644
                from, to     route.Vertex
×
645
        )
×
646

×
647
        r := &batch.Request[SQLQueries]{
×
648
                Opts: batch.NewSchedulerOptions(opts...),
×
649
                Reset: func() {
×
650
                        isUpdate1 = false
×
651
                        edgeNotFound = false
×
652
                },
×
653
                Do: func(tx SQLQueries) error {
×
654
                        var err error
×
655
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
656
                                ctx, tx, edge,
×
657
                        )
×
658
                        if err != nil {
×
659
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
660
                        }
×
661

662
                        // Silence ErrEdgeNotFound so that the batch can
663
                        // succeed, but propagate the error via local state.
664
                        if errors.Is(err, ErrEdgeNotFound) {
×
665
                                edgeNotFound = true
×
666
                                return nil
×
667
                        }
×
668

669
                        return err
×
670
                },
671
                OnCommit: func(err error) error {
×
672
                        switch {
×
673
                        case err != nil:
×
674
                                return err
×
675
                        case edgeNotFound:
×
676
                                return ErrEdgeNotFound
×
677
                        default:
×
678
                                s.updateEdgeCache(edge, isUpdate1)
×
679
                                return nil
×
680
                        }
681
                },
682
        }
683

684
        err := s.chanScheduler.Execute(ctx, r)
×
685

×
686
        return from, to, err
×
687
}
688

689
// updateEdgeCache updates our reject and channel caches with the new
690
// edge policy information.
691
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
692
        isUpdate1 bool) {
×
693

×
694
        // If an entry for this channel is found in reject cache, we'll modify
×
695
        // the entry with the updated timestamp for the direction that was just
×
696
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
697
        // during the next query for this edge.
×
698
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
699
                if isUpdate1 {
×
700
                        entry.upd1Time = e.LastUpdate.Unix()
×
701
                } else {
×
702
                        entry.upd2Time = e.LastUpdate.Unix()
×
703
                }
×
704
                s.rejectCache.insert(e.ChannelID, entry)
×
705
        }
706

707
        // If an entry for this channel is found in channel cache, we'll modify
708
        // the entry with the updated policy for the direction that was just
709
        // written. If the edge doesn't exist, we'll defer loading the info and
710
        // policies and lazily read from disk during the next query.
711
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
712
                if isUpdate1 {
×
713
                        channel.Policy1 = e
×
714
                } else {
×
715
                        channel.Policy2 = e
×
716
                }
×
717
                s.chanCache.insert(e.ChannelID, channel)
×
718
        }
719
}
720

721
// ForEachSourceNodeChannel iterates through all channels of the source node,
722
// executing the passed callback on each. The call-back is provided with the
723
// channel's outpoint, whether we have a policy for the channel and the channel
724
// peer's node information.
725
//
726
// NOTE: part of the V1Store interface.
727
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
728
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
729

×
730
        var ctx = context.TODO()
×
731

×
732
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
734
                if err != nil {
×
735
                        return fmt.Errorf("unable to fetch source node: %w",
×
736
                                err)
×
737
                }
×
738

739
                return forEachNodeChannel(
×
740
                        ctx, db, s.cfg.ChainHash, nodeID,
×
741
                        func(info *models.ChannelEdgeInfo,
×
742
                                outPolicy *models.ChannelEdgePolicy,
×
743
                                _ *models.ChannelEdgePolicy) error {
×
744

×
745
                                // Fetch the other node.
×
746
                                var (
×
747
                                        otherNodePub [33]byte
×
748
                                        node1        = info.NodeKey1Bytes
×
749
                                        node2        = info.NodeKey2Bytes
×
750
                                )
×
751
                                switch {
×
752
                                case bytes.Equal(node1[:], nodePub[:]):
×
753
                                        otherNodePub = node2
×
754
                                case bytes.Equal(node2[:], nodePub[:]):
×
755
                                        otherNodePub = node1
×
756
                                default:
×
757
                                        return fmt.Errorf("node not " +
×
758
                                                "participating in this channel")
×
759
                                }
760

761
                                _, otherNode, err := getNodeByPubKey(
×
762
                                        ctx, db, otherNodePub,
×
763
                                )
×
764
                                if err != nil {
×
765
                                        return fmt.Errorf("unable to fetch "+
×
766
                                                "other node(%x): %w",
×
767
                                                otherNodePub, err)
×
768
                                }
×
769

770
                                return cb(
×
771
                                        info.ChannelPoint, outPolicy != nil,
×
772
                                        otherNode,
×
773
                                )
×
774
                        },
775
                )
776
        }, sqldb.NoOpReset)
777
}
778

779
// ForEachNode iterates through all the stored vertices/nodes in the graph,
780
// executing the passed callback with each node encountered. If the callback
781
// returns an error, then the transaction is aborted and the iteration stops
782
// early. Any operations performed on the NodeTx passed to the call-back are
783
// executed under the same read transaction and so, methods on the NodeTx object
784
// _MUST_ only be called from within the call-back.
785
//
786
// NOTE: part of the V1Store interface.
787
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
788
        var (
×
789
                ctx          = context.TODO()
×
790
                lastID int64 = 0
×
791
        )
×
792

×
793
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
794
                node, err := buildNode(ctx, db, &dbNode)
×
795
                if err != nil {
×
796
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
797
                                dbNode.ID, err)
×
798
                }
×
799

800
                err = cb(
×
801
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
802
                )
×
803
                if err != nil {
×
804
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
805
                                dbNode.ID, err)
×
806
                }
×
807

808
                return nil
×
809
        }
810

811
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
812
                for {
×
813
                        nodes, err := db.ListNodesPaginated(
×
814
                                ctx, sqlc.ListNodesPaginatedParams{
×
815
                                        Version: int16(ProtocolV1),
×
816
                                        ID:      lastID,
×
817
                                        Limit:   pageSize,
×
818
                                },
×
819
                        )
×
820
                        if err != nil {
×
821
                                return fmt.Errorf("unable to fetch nodes: %w",
×
822
                                        err)
×
823
                        }
×
824

825
                        if len(nodes) == 0 {
×
826
                                break
×
827
                        }
828

829
                        for _, dbNode := range nodes {
×
830
                                err = handleNode(db, dbNode)
×
831
                                if err != nil {
×
832
                                        return err
×
833
                                }
×
834

835
                                lastID = dbNode.ID
×
836
                        }
837
                }
838

839
                return nil
×
840
        }, sqldb.NoOpReset)
841
}
842

843
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction. Its methods reuse the stored query handle,
// so any follow-up reads they perform run against the same handle that was
// used to load the node.
type sqlGraphNodeTx struct {
        // db is the query handle that follow-up reads are issued against.
        db    SQLQueries
        // id is the database ID of the node; it is used to look up the
        // node's channels.
        id    int64
        // node is the node model this NodeRTx was created for.
        node  *models.LightningNode
        // chain is the chain hash passed through when building channel
        // information for the node's channels.
        chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
855

856
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
857
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
858

×
859
        return &sqlGraphNodeTx{
×
860
                db:    db,
×
861
                chain: chain,
×
862
                id:    id,
×
863
                node:  node,
×
864
        }
×
865
}
×
866

867
// Node returns the raw information of the node.
868
//
869
// NOTE: This is a part of the NodeRTx interface.
870
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
871
        return s.node
×
872
}
×
873

874
// ForEachChannel can be used to iterate over the node's channels under the same
875
// transaction used to fetch the node.
876
//
877
// NOTE: This is a part of the NodeRTx interface.
878
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
879
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
880

×
881
        ctx := context.TODO()
×
882

×
883
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
884
}
×
885

886
// FetchNode fetches the node with the given pub key under the same transaction
887
// used to fetch the current node. The returned node is also a NodeRTx and any
888
// operations on that NodeRTx will also be done under the same transaction.
889
//
890
// NOTE: This is a part of the NodeRTx interface.
891
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
892
        ctx := context.TODO()
×
893

×
894
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
895
        if err != nil {
×
896
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
897
                        nodePub, err)
×
898
        }
×
899

900
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
901
}
902

903
// ForEachNodeDirectedChannel iterates through all channels of a given node,
904
// executing the passed callback on the directed edge representing the channel
905
// and its incoming policy. If the callback returns an error, then the iteration
906
// is halted with the error propagated back up to the caller.
907
//
908
// Unknown policies are passed into the callback as nil values.
909
//
910
// NOTE: this is part of the graphdb.NodeTraverser interface.
911
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
912
        cb func(channel *DirectedChannel) error) error {
×
913

×
914
        var ctx = context.TODO()
×
915

×
916
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
917
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
918
        }, sqldb.NoOpReset)
×
919
}
920

921
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
922
// graph, executing the passed callback with each node encountered. If the
923
// callback returns an error, then the transaction is aborted and the iteration
924
// stops early.
925
//
926
// NOTE: This is a part of the V1Store interface.
927
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
928
        *lnwire.FeatureVector) error) error {
×
929

×
930
        ctx := context.TODO()
×
931

×
932
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
933
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
934
                        nodePub route.Vertex) error {
×
935

×
936
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
937
                        if err != nil {
×
938
                                return fmt.Errorf("unable to fetch node "+
×
939
                                        "features: %w", err)
×
940
                        }
×
941

942
                        return cb(nodePub, features)
×
943
                })
944
        }, sqldb.NoOpReset)
945
        if err != nil {
×
946
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
947
        }
×
948

949
        return nil
×
950
}
951

952
// ForEachNodeChannel iterates through all channels of the given node,
953
// executing the passed callback with an edge info structure and the policies
954
// of each end of the channel. The first edge policy is the outgoing edge *to*
955
// the connecting node, while the second is the incoming edge *from* the
956
// connecting node. If the callback returns an error, then the iteration is
957
// halted with the error propagated back up to the caller.
958
//
959
// Unknown policies are passed into the callback as nil values.
960
//
961
// NOTE: part of the V1Store interface.
962
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
963
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
964
                *models.ChannelEdgePolicy) error) error {
×
965

×
966
        var ctx = context.TODO()
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, sqldb.NoOpReset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.Node, row.Node_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1045
                                row.Channel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1113
        chans map[uint64]*DirectedChannel) error) error {
×
1114

×
1115
        var ctx = context.TODO()
×
1116

×
1117
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1118
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1119
                        nodePub route.Vertex) error {
×
1120

×
1121
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1122
                        if err != nil {
×
1123
                                return fmt.Errorf("unable to fetch "+
×
1124
                                        "node(id=%d) features: %w", nodeID, err)
×
1125
                        }
×
1126

1127
                        toNodeCallback := func() route.Vertex {
×
1128
                                return nodePub
×
1129
                        }
×
1130

1131
                        rows, err := db.ListChannelsByNodeID(
×
1132
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1133
                                        Version: int16(ProtocolV1),
×
1134
                                        NodeID1: nodeID,
×
1135
                                },
×
1136
                        )
×
1137
                        if err != nil {
×
1138
                                return fmt.Errorf("unable to fetch channels "+
×
1139
                                        "of node(id=%d): %w", nodeID, err)
×
1140
                        }
×
1141

1142
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1143
                        for _, row := range rows {
×
1144
                                node1, node2, err := buildNodeVertices(
×
1145
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1146
                                )
×
1147
                                if err != nil {
×
1148
                                        return err
×
1149
                                }
×
1150

1151
                                e, err := getAndBuildEdgeInfo(
×
1152
                                        ctx, db, s.cfg.ChainHash,
×
1153
                                        row.Channel.ID, row.Channel, node1,
×
1154
                                        node2,
×
1155
                                )
×
1156
                                if err != nil {
×
1157
                                        return fmt.Errorf("unable to build "+
×
1158
                                                "channel info: %w", err)
×
1159
                                }
×
1160

1161
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1162
                                        row,
×
1163
                                )
×
1164
                                if err != nil {
×
1165
                                        return fmt.Errorf("unable to "+
×
1166
                                                "extract channel "+
×
1167
                                                "policies: %w", err)
×
1168
                                }
×
1169

1170
                                p1, p2, err := getAndBuildChanPolicies(
×
1171
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1172
                                        node1, node2,
×
1173
                                )
×
1174
                                if err != nil {
×
1175
                                        return fmt.Errorf("unable to "+
×
1176
                                                "build channel policies: %w",
×
1177
                                                err)
×
1178
                                }
×
1179

1180
                                // Determine the outgoing and incoming policy
1181
                                // for this channel and node combo.
1182
                                outPolicy, inPolicy := p1, p2
×
1183
                                if p1 != nil && p1.ToNode == nodePub {
×
1184
                                        outPolicy, inPolicy = p2, p1
×
1185
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1186
                                        outPolicy, inPolicy = p2, p1
×
1187
                                }
×
1188

1189
                                var cachedInPolicy *models.CachedEdgePolicy
×
1190
                                if inPolicy != nil {
×
1191
                                        cachedInPolicy = models.NewCachedPolicy(
×
1192
                                                p2,
×
1193
                                        )
×
1194
                                        cachedInPolicy.ToNodePubKey =
×
1195
                                                toNodeCallback
×
1196
                                        cachedInPolicy.ToNodeFeatures =
×
1197
                                                features
×
1198
                                }
×
1199

1200
                                var inboundFee lnwire.Fee
×
1201
                                outPolicy.InboundFee.WhenSome(
×
1202
                                        func(fee lnwire.Fee) {
×
1203
                                                inboundFee = fee
×
1204
                                        },
×
1205
                                )
1206

1207
                                directedChannel := &DirectedChannel{
×
1208
                                        ChannelID: e.ChannelID,
×
1209
                                        IsNode1: nodePub ==
×
1210
                                                e.NodeKey1Bytes,
×
1211
                                        OtherNode:    e.NodeKey2Bytes,
×
1212
                                        Capacity:     e.Capacity,
×
1213
                                        OutPolicySet: p1 != nil,
×
1214
                                        InPolicy:     cachedInPolicy,
×
1215
                                        InboundFee:   inboundFee,
×
1216
                                }
×
1217

×
1218
                                if nodePub == e.NodeKey2Bytes {
×
1219
                                        directedChannel.OtherNode =
×
1220
                                                e.NodeKey1Bytes
×
1221
                                }
×
1222

1223
                                channels[e.ChannelID] = directedChannel
×
1224
                        }
1225

1226
                        return cb(nodePub, channels)
×
1227
                })
1228
        }, sqldb.NoOpReset)
1229
}
1230

1231
// ForEachChannelCacheable iterates through all the channel edges stored
1232
// within the graph and invokes the passed callback for each edge. The
1233
// callback takes two edges as since this is a directed graph, both the
1234
// in/out edges are visited. If the callback returns an error, then the
1235
// transaction is aborted and the iteration stops early.
1236
//
1237
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1238
// pointer for that particular channel edge routing policy will be
1239
// passed into the callback.
1240
//
1241
// NOTE: this method is like ForEachChannel but fetches only the data
1242
// required for the graph cache.
1243
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1244
        *models.CachedEdgePolicy,
1245
        *models.CachedEdgePolicy) error) error {
×
1246

×
1247
        ctx := context.TODO()
×
1248

×
1249
        handleChannel := func(db SQLQueries,
×
1250
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1251

×
1252
                node1, node2, err := buildNodeVertices(
×
1253
                        row.Node1Pubkey, row.Node2Pubkey,
×
1254
                )
×
1255
                if err != nil {
×
1256
                        return err
×
1257
                }
×
1258

1259
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1260

×
1261
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

1266
                var pol1, pol2 *models.CachedEdgePolicy
×
1267
                if dbPol1 != nil {
×
1268
                        policy1, err := buildChanPolicy(
×
1269
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1270
                        )
×
1271
                        if err != nil {
×
1272
                                return err
×
1273
                        }
×
1274

1275
                        pol1 = models.NewCachedPolicy(policy1)
×
1276
                }
1277
                if dbPol2 != nil {
×
1278
                        policy2, err := buildChanPolicy(
×
1279
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1280
                        )
×
1281
                        if err != nil {
×
1282
                                return err
×
1283
                        }
×
1284

1285
                        pol2 = models.NewCachedPolicy(policy2)
×
1286
                }
1287

1288
                if err := cb(edge, pol1, pol2); err != nil {
×
1289
                        return err
×
1290
                }
×
1291

1292
                return nil
×
1293
        }
1294

1295
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1296
                lastID := int64(-1)
×
1297
                for {
×
1298
                        //nolint:ll
×
1299
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1300
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1301
                                        Version: int16(ProtocolV1),
×
1302
                                        ID:      lastID,
×
1303
                                        Limit:   pageSize,
×
1304
                                },
×
1305
                        )
×
1306
                        if err != nil {
×
1307
                                return err
×
1308
                        }
×
1309

1310
                        if len(rows) == 0 {
×
1311
                                break
×
1312
                        }
1313

1314
                        for _, row := range rows {
×
1315
                                err := handleChannel(db, row)
×
1316
                                if err != nil {
×
1317
                                        return err
×
1318
                                }
×
1319

1320
                                lastID = row.Channel.ID
×
1321
                        }
1322
                }
1323

1324
                return nil
×
1325
        }, sqldb.NoOpReset)
1326
}
1327

1328
// ForEachChannel iterates through all the channel edges stored within the
1329
// graph and invokes the passed callback for each edge. The callback takes two
1330
// edges as since this is a directed graph, both the in/out edges are visited.
1331
// If the callback returns an error, then the transaction is aborted and the
1332
// iteration stops early.
1333
//
1334
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1335
// for that particular channel edge routing policy will be passed into the
1336
// callback.
1337
//
1338
// NOTE: part of the V1Store interface.
1339
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1340
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1341

×
1342
        ctx := context.TODO()
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1357
                        node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.Channel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, sqldb.NoOpReset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart[:],
×
1456
                                EndScid:   chanIDEnd[:],
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB[:],
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB[:],
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB[:],
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB[:],
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.Node.PubKey, row.Node_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1746
                                row.Channel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB[:],
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
        )
×
1837
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1838
                var chanIDB [8]byte
×
1839
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1840

×
1841
                row, err := db.GetChannelBySCIDWithPolicies(
×
1842
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1843
                                Scid:    chanIDB[:],
×
1844
                                Version: int16(ProtocolV1),
×
1845
                        },
×
1846
                )
×
1847
                if errors.Is(err, sql.ErrNoRows) {
×
1848
                        // First check if this edge is perhaps in the zombie
×
1849
                        // index.
×
1850
                        isZombie, err := db.IsZombieChannel(
×
1851
                                ctx, sqlc.IsZombieChannelParams{
×
1852
                                        Scid:    chanIDB[:],
×
1853
                                        Version: int16(ProtocolV1),
×
1854
                                },
×
1855
                        )
×
1856
                        if err != nil {
×
1857
                                return fmt.Errorf("unable to check if "+
×
1858
                                        "channel is zombie: %w", err)
×
1859
                        } else if isZombie {
×
1860
                                return ErrZombieEdge
×
1861
                        }
×
1862

1863
                        return ErrEdgeNotFound
×
1864
                } else if err != nil {
×
1865
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1866
                }
×
1867

1868
                node1, node2, err := buildNodeVertices(
×
1869
                        row.Node.PubKey, row.Node_2.PubKey,
×
1870
                )
×
1871
                if err != nil {
×
1872
                        return err
×
1873
                }
×
1874

1875
                edge, err = getAndBuildEdgeInfo(
×
1876
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1877
                        node1, node2,
×
1878
                )
×
1879
                if err != nil {
×
1880
                        return fmt.Errorf("unable to build channel info: %w",
×
1881
                                err)
×
1882
                }
×
1883

1884
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1885
                if err != nil {
×
1886
                        return fmt.Errorf("unable to extract channel "+
×
1887
                                "policies: %w", err)
×
1888
                }
×
1889

1890
                policy1, policy2, err = getAndBuildChanPolicies(
×
1891
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1892
                )
×
1893
                if err != nil {
×
1894
                        return fmt.Errorf("unable to build channel "+
×
1895
                                "policies: %w", err)
×
1896
                }
×
1897

1898
                return nil
×
1899
        }, sqldb.NoOpReset)
1900
        if err != nil {
×
1901
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1902
                        err)
×
1903
        }
×
1904

1905
        return edge, policy1, policy2, nil
×
1906
}
1907

1908
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1909
// the channel identified by the funding outpoint. If the channel can't be
1910
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1911
// information for the channel itself is returned as well as two structs that
1912
// contain the routing policies for the channel in either direction.
1913
//
1914
// NOTE: part of the V1Store interface.
1915
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1916
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1917
        *models.ChannelEdgePolicy, error) {
×
1918

×
1919
        var (
×
1920
                ctx              = context.TODO()
×
1921
                edge             *models.ChannelEdgeInfo
×
1922
                policy1, policy2 *models.ChannelEdgePolicy
×
1923
        )
×
1924
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1925
                row, err := db.GetChannelByOutpointWithPolicies(
×
1926
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1927
                                Outpoint: op.String(),
×
1928
                                Version:  int16(ProtocolV1),
×
1929
                        },
×
1930
                )
×
1931
                if errors.Is(err, sql.ErrNoRows) {
×
1932
                        return ErrEdgeNotFound
×
1933
                } else if err != nil {
×
1934
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1935
                }
×
1936

1937
                node1, node2, err := buildNodeVertices(
×
1938
                        row.Node1Pubkey, row.Node2Pubkey,
×
1939
                )
×
1940
                if err != nil {
×
1941
                        return err
×
1942
                }
×
1943

1944
                edge, err = getAndBuildEdgeInfo(
×
1945
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1946
                        node1, node2,
×
1947
                )
×
1948
                if err != nil {
×
1949
                        return fmt.Errorf("unable to build channel info: %w",
×
1950
                                err)
×
1951
                }
×
1952

1953
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1954
                if err != nil {
×
1955
                        return fmt.Errorf("unable to extract channel "+
×
1956
                                "policies: %w", err)
×
1957
                }
×
1958

1959
                policy1, policy2, err = getAndBuildChanPolicies(
×
1960
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1961
                )
×
1962
                if err != nil {
×
1963
                        return fmt.Errorf("unable to build channel "+
×
1964
                                "policies: %w", err)
×
1965
                }
×
1966

1967
                return nil
×
1968
        }, sqldb.NoOpReset)
1969
        if err != nil {
×
1970
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1971
                        err)
×
1972
        }
×
1973

1974
        return edge, policy1, policy2, nil
×
1975
}
1976

1977
// HasChannelEdge returns true if the database knows of a channel edge with the
1978
// passed channel ID, and false otherwise. If an edge with that ID is found
1979
// within the graph, then two time stamps representing the last time the edge
1980
// was updated for both directed edges are returned along with the boolean. If
1981
// it is not found, then the zombie index is checked and its result is returned
1982
// as the second boolean.
1983
//
1984
// NOTE: part of the V1Store interface.
1985
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1986
        bool, error) {
×
1987

×
1988
        ctx := context.TODO()
×
1989

×
1990
        var (
×
1991
                exists          bool
×
1992
                isZombie        bool
×
1993
                node1LastUpdate time.Time
×
1994
                node2LastUpdate time.Time
×
1995
        )
×
1996

×
1997
        // We'll query the cache with the shared lock held to allow multiple
×
1998
        // readers to access values in the cache concurrently if they exist.
×
1999
        s.cacheMu.RLock()
×
2000
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2001
                s.cacheMu.RUnlock()
×
2002
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2003
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2004
                exists, isZombie = entry.flags.unpack()
×
2005

×
2006
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2007
        }
×
2008
        s.cacheMu.RUnlock()
×
2009

×
2010
        s.cacheMu.Lock()
×
2011
        defer s.cacheMu.Unlock()
×
2012

×
2013
        // The item was not found with the shared lock, so we'll acquire the
×
2014
        // exclusive lock and check the cache again in case another method added
×
2015
        // the entry to the cache while no lock was held.
×
2016
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2017
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2018
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2019
                exists, isZombie = entry.flags.unpack()
×
2020

×
2021
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2022
        }
×
2023

2024
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2025
                var chanIDB [8]byte
×
2026
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2027

×
2028
                channel, err := db.GetChannelBySCID(
×
2029
                        ctx, sqlc.GetChannelBySCIDParams{
×
2030
                                Scid:    chanIDB[:],
×
2031
                                Version: int16(ProtocolV1),
×
2032
                        },
×
2033
                )
×
2034
                if errors.Is(err, sql.ErrNoRows) {
×
2035
                        // Check if it is a zombie channel.
×
2036
                        isZombie, err = db.IsZombieChannel(
×
2037
                                ctx, sqlc.IsZombieChannelParams{
×
2038
                                        Scid:    chanIDB[:],
×
2039
                                        Version: int16(ProtocolV1),
×
2040
                                },
×
2041
                        )
×
2042
                        if err != nil {
×
2043
                                return fmt.Errorf("could not check if channel "+
×
2044
                                        "is zombie: %w", err)
×
2045
                        }
×
2046

2047
                        return nil
×
2048
                } else if err != nil {
×
2049
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2050
                }
×
2051

2052
                exists = true
×
2053

×
2054
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2055
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2056
                                Version:   int16(ProtocolV1),
×
2057
                                ChannelID: channel.ID,
×
2058
                                NodeID:    channel.NodeID1,
×
2059
                        },
×
2060
                )
×
2061
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2062
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2063
                                err)
×
2064
                } else if err == nil {
×
2065
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2066
                }
×
2067

2068
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2069
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2070
                                Version:   int16(ProtocolV1),
×
2071
                                ChannelID: channel.ID,
×
2072
                                NodeID:    channel.NodeID2,
×
2073
                        },
×
2074
                )
×
2075
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2076
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2077
                                err)
×
2078
                } else if err == nil {
×
2079
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2080
                }
×
2081

2082
                return nil
×
2083
        }, sqldb.NoOpReset)
2084
        if err != nil {
×
2085
                return time.Time{}, time.Time{}, false, false,
×
2086
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2087
        }
×
2088

2089
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2090
                upd1Time: node1LastUpdate.Unix(),
×
2091
                upd2Time: node2LastUpdate.Unix(),
×
2092
                flags:    packRejectFlags(exists, isZombie),
×
2093
        })
×
2094

×
2095
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2096
}
2097

2098
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2099
// passed channel point (outpoint). If the passed channel doesn't exist within
2100
// the database, then ErrEdgeNotFound is returned.
2101
//
2102
// NOTE: part of the V1Store interface.
2103
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2104
        var (
×
2105
                ctx       = context.TODO()
×
2106
                channelID uint64
×
2107
        )
×
2108
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2109
                chanID, err := db.GetSCIDByOutpoint(
×
2110
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2111
                                Outpoint: chanPoint.String(),
×
2112
                                Version:  int16(ProtocolV1),
×
2113
                        },
×
2114
                )
×
2115
                if errors.Is(err, sql.ErrNoRows) {
×
2116
                        return ErrEdgeNotFound
×
2117
                } else if err != nil {
×
2118
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2119
                                err)
×
2120
                }
×
2121

2122
                channelID = byteOrder.Uint64(chanID)
×
2123

×
2124
                return nil
×
2125
        }, sqldb.NoOpReset)
2126
        if err != nil {
×
2127
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2128
        }
×
2129

2130
        return channelID, nil
×
2131
}
2132

2133
// IsPublicNode is a helper method that determines whether the node with the
2134
// given public key is seen as a public node in the graph from the graph's
2135
// source node's point of view.
2136
//
2137
// NOTE: part of the V1Store interface.
2138
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2139
        ctx := context.TODO()
×
2140

×
2141
        var isPublic bool
×
2142
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2143
                var err error
×
2144
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2145

×
2146
                return err
×
2147
        }, sqldb.NoOpReset)
×
2148
        if err != nil {
×
2149
                return false, fmt.Errorf("unable to check if node is "+
×
2150
                        "public: %w", err)
×
2151
        }
×
2152

2153
        return isPublic, nil
×
2154
}
2155

2156
// FetchChanInfos returns the set of channel edges that correspond to the passed
2157
// channel ID's. If an edge is the query is unknown to the database, it will
2158
// skipped and the result will contain only those edges that exist at the time
2159
// of the query. This can be used to respond to peer queries that are seeking to
2160
// fill in gaps in their view of the channel graph.
2161
//
2162
// NOTE: part of the V1Store interface.
2163
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2164
        var (
×
2165
                ctx   = context.TODO()
×
2166
                edges []ChannelEdge
×
2167
        )
×
2168
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2169
                for _, chanID := range chanIDs {
×
2170
                        var chanIDB [8]byte
×
2171
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2172

×
2173
                        // TODO(elle): potentially optimize this by using
×
2174
                        //  sqlc.slice() once that works for both SQLite and
×
2175
                        //  Postgres.
×
2176
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2177
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2178
                                        Scid:    chanIDB[:],
×
2179
                                        Version: int16(ProtocolV1),
×
2180
                                },
×
2181
                        )
×
2182
                        if errors.Is(err, sql.ErrNoRows) {
×
2183
                                continue
×
2184
                        } else if err != nil {
×
2185
                                return fmt.Errorf("unable to fetch channel: %w",
×
2186
                                        err)
×
2187
                        }
×
2188

2189
                        node1, node2, err := buildNodes(
×
2190
                                ctx, db, row.Node, row.Node_2,
×
2191
                        )
×
2192
                        if err != nil {
×
2193
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2194
                                        err)
×
2195
                        }
×
2196

2197
                        edge, err := getAndBuildEdgeInfo(
×
2198
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2199
                                row.Channel, node1.PubKeyBytes,
×
2200
                                node2.PubKeyBytes,
×
2201
                        )
×
2202
                        if err != nil {
×
2203
                                return fmt.Errorf("unable to build "+
×
2204
                                        "channel info: %w", err)
×
2205
                        }
×
2206

2207
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to extract channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

2213
                        p1, p2, err := getAndBuildChanPolicies(
×
2214
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2215
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2216
                        )
×
2217
                        if err != nil {
×
2218
                                return fmt.Errorf("unable to build channel "+
×
2219
                                        "policies: %w", err)
×
2220
                        }
×
2221

2222
                        edges = append(edges, ChannelEdge{
×
2223
                                Info:    edge,
×
2224
                                Policy1: p1,
×
2225
                                Policy2: p2,
×
2226
                                Node1:   node1,
×
2227
                                Node2:   node2,
×
2228
                        })
×
2229
                }
2230

2231
                return nil
×
2232
        }, func() {
×
2233
                edges = nil
×
2234
        })
×
2235
        if err != nil {
×
2236
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2237
        }
×
2238

2239
        return edges, nil
×
2240
}
2241

2242
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2243
// ID's that we don't know and are not known zombies of the passed set. In other
2244
// words, we perform a set difference of our set of chan ID's and the ones
2245
// passed in. This method can be used by callers to determine the set of
2246
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2247
// known zombies is also returned.
2248
//
2249
// NOTE: part of the V1Store interface.
2250
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2251
        []ChannelUpdateInfo, error) {
×
2252

×
2253
        var (
×
2254
                ctx          = context.TODO()
×
2255
                newChanIDs   []uint64
×
2256
                knownZombies []ChannelUpdateInfo
×
2257
        )
×
2258
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2259
                for _, chanInfo := range chansInfo {
×
2260
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2261
                        var chanIDB [8]byte
×
2262
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2263

×
2264
                        // TODO(elle): potentially optimize this by using
×
2265
                        //  sqlc.slice() once that works for both SQLite and
×
2266
                        //  Postgres.
×
2267
                        _, err := db.GetChannelBySCID(
×
2268
                                ctx, sqlc.GetChannelBySCIDParams{
×
2269
                                        Version: int16(ProtocolV1),
×
2270
                                        Scid:    chanIDB[:],
×
2271
                                },
×
2272
                        )
×
2273
                        if err == nil {
×
2274
                                continue
×
2275
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2276
                                return fmt.Errorf("unable to fetch channel: %w",
×
2277
                                        err)
×
2278
                        }
×
2279

2280
                        isZombie, err := db.IsZombieChannel(
×
2281
                                ctx, sqlc.IsZombieChannelParams{
×
2282
                                        Scid:    chanIDB[:],
×
2283
                                        Version: int16(ProtocolV1),
×
2284
                                },
×
2285
                        )
×
2286
                        if err != nil {
×
2287
                                return fmt.Errorf("unable to fetch zombie "+
×
2288
                                        "channel: %w", err)
×
2289
                        }
×
2290

2291
                        if isZombie {
×
2292
                                knownZombies = append(knownZombies, chanInfo)
×
2293

×
2294
                                continue
×
2295
                        }
2296

2297
                        newChanIDs = append(newChanIDs, channelID)
×
2298
                }
2299

2300
                return nil
×
2301
        }, func() {
×
2302
                newChanIDs = nil
×
2303
                knownZombies = nil
×
2304
        })
×
2305
        if err != nil {
×
2306
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2307
        }
×
2308

2309
        return newChanIDs, knownZombies, nil
×
2310
}
2311

2312
// PruneGraphNodes is a garbage collection method which attempts to prune out
2313
// any nodes from the channel graph that are currently unconnected. This ensure
2314
// that we only maintain a graph of reachable nodes. In the event that a pruned
2315
// node gains more channels, it will be re-added back to the graph.
2316
//
2317
// NOTE: this prunes nodes across protocol versions. It will never prune the
2318
// source nodes.
2319
//
2320
// NOTE: part of the V1Store interface.
2321
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2322
        var ctx = context.TODO()
×
2323

×
2324
        var prunedNodes []route.Vertex
×
2325
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2326
                var err error
×
2327
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2328

×
2329
                return err
×
2330
        }, func() {
×
2331
                prunedNodes = nil
×
2332
        })
×
2333
        if err != nil {
×
2334
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2335
        }
×
2336

2337
        return prunedNodes, nil
×
2338
}
2339

2340
// PruneGraph prunes newly closed channels from the channel graph in response
2341
// to a new block being solved on the network. Any transactions which spend the
2342
// funding output of any known channels within he graph will be deleted.
2343
// Additionally, the "prune tip", or the last block which has been used to
2344
// prune the graph is stored so callers can ensure the graph is fully in sync
2345
// with the current UTXO state. A slice of channels that have been closed by
2346
// the target block along with any pruned nodes are returned if the function
2347
// succeeds without error.
2348
//
2349
// NOTE: part of the V1Store interface.
2350
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2351
        blockHash *chainhash.Hash, blockHeight uint32) (
2352
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2353

×
2354
        ctx := context.TODO()
×
2355

×
2356
        s.cacheMu.Lock()
×
2357
        defer s.cacheMu.Unlock()
×
2358

×
2359
        var (
×
2360
                closedChans []*models.ChannelEdgeInfo
×
2361
                prunedNodes []route.Vertex
×
2362
        )
×
2363
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2364
                for _, outpoint := range spentOutputs {
×
2365
                        // TODO(elle): potentially optimize this by using
×
2366
                        //  sqlc.slice() once that works for both SQLite and
×
2367
                        //  Postgres.
×
2368
                        //
×
2369
                        // NOTE: this fetches channels for all protocol
×
2370
                        // versions.
×
2371
                        row, err := db.GetChannelByOutpoint(
×
2372
                                ctx, outpoint.String(),
×
2373
                        )
×
2374
                        if errors.Is(err, sql.ErrNoRows) {
×
2375
                                continue
×
2376
                        } else if err != nil {
×
2377
                                return fmt.Errorf("unable to fetch channel: %w",
×
2378
                                        err)
×
2379
                        }
×
2380

2381
                        node1, node2, err := buildNodeVertices(
×
2382
                                row.Node1Pubkey, row.Node2Pubkey,
×
2383
                        )
×
2384
                        if err != nil {
×
2385
                                return err
×
2386
                        }
×
2387

2388
                        info, err := getAndBuildEdgeInfo(
×
2389
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2390
                                row.Channel, node1, node2,
×
2391
                        )
×
2392
                        if err != nil {
×
2393
                                return err
×
2394
                        }
×
2395

2396
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2397
                        if err != nil {
×
2398
                                return fmt.Errorf("unable to delete "+
×
2399
                                        "channel: %w", err)
×
2400
                        }
×
2401

2402
                        closedChans = append(closedChans, info)
×
2403
                }
2404

2405
                err := db.UpsertPruneLogEntry(
×
2406
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2407
                                BlockHash:   blockHash[:],
×
2408
                                BlockHeight: int64(blockHeight),
×
2409
                        },
×
2410
                )
×
2411
                if err != nil {
×
2412
                        return fmt.Errorf("unable to insert prune log "+
×
2413
                                "entry: %w", err)
×
2414
                }
×
2415

2416
                // Now that we've pruned some channels, we'll also prune any
2417
                // nodes that no longer have any channels.
2418
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2419
                if err != nil {
×
2420
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2421
                                err)
×
2422
                }
×
2423

2424
                return nil
×
2425
        }, func() {
×
2426
                prunedNodes = nil
×
2427
                closedChans = nil
×
2428
        })
×
2429
        if err != nil {
×
2430
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2431
        }
×
2432

2433
        for _, channel := range closedChans {
×
2434
                s.rejectCache.remove(channel.ChannelID)
×
2435
                s.chanCache.remove(channel.ChannelID)
×
2436
        }
×
2437

2438
        return closedChans, prunedNodes, nil
×
2439
}
2440

2441
// ChannelView returns the verifiable edge information for each active channel
2442
// within the known channel graph. The set of UTXOs (along with their scripts)
2443
// returned are the ones that need to be watched on chain to detect channel
2444
// closes on the resident blockchain.
2445
//
2446
// NOTE: part of the V1Store interface.
2447
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2448
        var (
×
2449
                ctx        = context.TODO()
×
2450
                edgePoints []EdgePoint
×
2451
        )
×
2452

×
2453
        handleChannel := func(db SQLQueries,
×
2454
                channel sqlc.ListChannelsPaginatedRow) error {
×
2455

×
2456
                pkScript, err := genMultiSigP2WSH(
×
2457
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2458
                )
×
2459
                if err != nil {
×
2460
                        return err
×
2461
                }
×
2462

2463
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2464
                if err != nil {
×
2465
                        return err
×
2466
                }
×
2467

2468
                edgePoints = append(edgePoints, EdgePoint{
×
2469
                        FundingPkScript: pkScript,
×
2470
                        OutPoint:        *op,
×
2471
                })
×
2472

×
2473
                return nil
×
2474
        }
2475

2476
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2477
                lastID := int64(-1)
×
2478
                for {
×
2479
                        rows, err := db.ListChannelsPaginated(
×
2480
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2481
                                        Version: int16(ProtocolV1),
×
2482
                                        ID:      lastID,
×
2483
                                        Limit:   pageSize,
×
2484
                                },
×
2485
                        )
×
2486
                        if err != nil {
×
2487
                                return err
×
2488
                        }
×
2489

2490
                        if len(rows) == 0 {
×
2491
                                break
×
2492
                        }
2493

2494
                        for _, row := range rows {
×
2495
                                err := handleChannel(db, row)
×
2496
                                if err != nil {
×
2497
                                        return err
×
2498
                                }
×
2499

2500
                                lastID = row.ID
×
2501
                        }
2502
                }
2503

2504
                return nil
×
2505
        }, func() {
×
2506
                edgePoints = nil
×
2507
        })
×
2508
        if err != nil {
×
2509
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2510
        }
×
2511

2512
        return edgePoints, nil
×
2513
}
2514

2515
// PruneTip returns the block height and hash of the latest block that has been
2516
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2517
// to tell if the graph is currently in sync with the current best known UTXO
2518
// state.
2519
//
2520
// NOTE: part of the V1Store interface.
2521
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2522
        var (
×
2523
                ctx       = context.TODO()
×
2524
                tipHash   chainhash.Hash
×
2525
                tipHeight uint32
×
2526
        )
×
2527
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2528
                pruneTip, err := db.GetPruneTip(ctx)
×
2529
                if errors.Is(err, sql.ErrNoRows) {
×
2530
                        return ErrGraphNeverPruned
×
2531
                } else if err != nil {
×
2532
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2533
                }
×
2534

2535
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2536
                tipHeight = uint32(pruneTip.BlockHeight)
×
2537

×
2538
                return nil
×
2539
        }, sqldb.NoOpReset)
2540
        if err != nil {
×
2541
                return nil, 0, err
×
2542
        }
×
2543

2544
        return &tipHash, tipHeight, nil
×
2545
}
2546

2547
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2548
//
2549
// NOTE: this prunes nodes across protocol versions. It will never prune the
2550
// source nodes.
2551
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2552
        db SQLQueries) ([]route.Vertex, error) {
×
2553

×
NEW
2554
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2555
        if err != nil {
×
NEW
2556
                return nil, fmt.Errorf("unable to delete unconnected "+
×
NEW
2557
                        "nodes: %w", err)
×
UNCOV
2558
        }
×
2559

NEW
2560
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
NEW
2561
        for i, nodeKey := range nodeKeys {
×
NEW
2562
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2563
                if err != nil {
×
2564
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
NEW
2565
                                "from bytes: %w", err)
×
2566
                }
×
2567

NEW
2568
                prunedNodes[i] = pub
×
2569
        }
2570

2571
        return prunedNodes, nil
×
2572
}
2573

2574
// DisconnectBlockAtHeight is used to indicate that the block specified
2575
// by the passed height has been disconnected from the main chain. This
2576
// will "rewind" the graph back to the height below, deleting channels
2577
// that are no longer confirmed from the graph. The prune log will be
2578
// set to the last prune height valid for the remaining chain.
2579
// Channels that were removed from the graph resulting from the
2580
// disconnected block are returned.
2581
//
2582
// NOTE: part of the V1Store interface.
2583
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2584
        []*models.ChannelEdgeInfo, error) {
×
2585

×
2586
        ctx := context.TODO()
×
2587

×
2588
        var (
×
2589
                // Every channel having a ShortChannelID starting at 'height'
×
2590
                // will no longer be confirmed.
×
2591
                startShortChanID = lnwire.ShortChannelID{
×
2592
                        BlockHeight: height,
×
2593
                }
×
2594

×
2595
                // Delete everything after this height from the db up until the
×
2596
                // SCID alias range.
×
2597
                endShortChanID = aliasmgr.StartingAlias
×
2598

×
2599
                removedChans []*models.ChannelEdgeInfo
×
2600
        )
×
2601

×
2602
        var chanIDStart [8]byte
×
2603
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2604
        var chanIDEnd [8]byte
×
2605
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2606

×
2607
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2608
                rows, err := db.GetChannelsBySCIDRange(
×
2609
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2610
                                StartScid: chanIDStart[:],
×
2611
                                EndScid:   chanIDEnd[:],
×
2612
                        },
×
2613
                )
×
2614
                if err != nil {
×
2615
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2616
                }
×
2617

2618
                for _, row := range rows {
×
2619
                        node1, node2, err := buildNodeVertices(
×
2620
                                row.Node1PubKey, row.Node2PubKey,
×
2621
                        )
×
2622
                        if err != nil {
×
2623
                                return err
×
2624
                        }
×
2625

2626
                        channel, err := getAndBuildEdgeInfo(
×
2627
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2628
                                row.Channel, node1, node2,
×
2629
                        )
×
2630
                        if err != nil {
×
2631
                                return err
×
2632
                        }
×
2633

2634
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2635
                        if err != nil {
×
2636
                                return fmt.Errorf("unable to delete "+
×
2637
                                        "channel: %w", err)
×
2638
                        }
×
2639

2640
                        removedChans = append(removedChans, channel)
×
2641
                }
2642

2643
                return db.DeletePruneLogEntriesInRange(
×
2644
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2645
                                StartHeight: int64(height),
×
2646
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2647
                        },
×
2648
                )
×
2649
        }, func() {
×
2650
                removedChans = nil
×
2651
        })
×
2652
        if err != nil {
×
2653
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2654
                        "height: %w", err)
×
2655
        }
×
2656

2657
        for _, channel := range removedChans {
×
2658
                s.rejectCache.remove(channel.ChannelID)
×
2659
                s.chanCache.remove(channel.ChannelID)
×
2660
        }
×
2661

2662
        return removedChans, nil
×
2663
}
2664

2665
// AddEdgeProof sets the proof of an existing edge in the graph database.
2666
//
2667
// NOTE: part of the V1Store interface.
2668
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2669
        proof *models.ChannelAuthProof) error {
×
2670

×
2671
        var (
×
2672
                ctx       = context.TODO()
×
2673
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2674
        )
×
2675

×
2676
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2677
                res, err := db.AddV1ChannelProof(
×
2678
                        ctx, sqlc.AddV1ChannelProofParams{
×
2679
                                Scid:              scidBytes[:],
×
2680
                                Node1Signature:    proof.NodeSig1Bytes,
×
2681
                                Node2Signature:    proof.NodeSig2Bytes,
×
2682
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2683
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2684
                        },
×
2685
                )
×
2686
                if err != nil {
×
2687
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2688
                }
×
2689

2690
                n, err := res.RowsAffected()
×
2691
                if err != nil {
×
2692
                        return err
×
2693
                }
×
2694

2695
                if n == 0 {
×
2696
                        return fmt.Errorf("no rows affected when adding edge "+
×
2697
                                "proof for SCID %v", scid)
×
2698
                } else if n > 1 {
×
2699
                        return fmt.Errorf("multiple rows affected when adding "+
×
2700
                                "edge proof for SCID %v: %d rows affected",
×
2701
                                scid, n)
×
2702
                }
×
2703

2704
                return nil
×
2705
        }, sqldb.NoOpReset)
2706
        if err != nil {
×
2707
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2708
        }
×
2709

2710
        return nil
×
2711
}
2712

2713
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2714
// that we can ignore channel announcements that we know to be closed without
2715
// having to validate them and fetch a block.
2716
//
2717
// NOTE: part of the V1Store interface.
2718
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2719
        var (
×
2720
                ctx     = context.TODO()
×
2721
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2722
        )
×
2723

×
2724
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2725
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
2726
        }, sqldb.NoOpReset)
×
2727
}
2728

2729
// IsClosedScid checks whether a channel identified by the passed in scid is
2730
// closed. This helps avoid having to perform expensive validation checks.
2731
//
2732
// NOTE: part of the V1Store interface.
2733
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2734
        var (
×
2735
                ctx      = context.TODO()
×
2736
                isClosed bool
×
2737
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2738
        )
×
2739
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2740
                var err error
×
2741
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
2742
                if err != nil {
×
2743
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2744
                                err)
×
2745
                }
×
2746

2747
                return nil
×
2748
        }, sqldb.NoOpReset)
2749
        if err != nil {
×
2750
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2751
                        err)
×
2752
        }
×
2753

2754
        return isClosed, nil
×
2755
}
2756

2757
// GraphSession will provide the call-back with access to a NodeTraverser
2758
// instance which can be used to perform queries against the channel graph.
2759
//
2760
// NOTE: part of the V1Store interface.
2761
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2762
        var ctx = context.TODO()
×
2763

×
2764
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2765
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2766
        }, sqldb.NoOpReset)
×
2767
}
2768

2769
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
	// db is the query handle that all of the traverser's lookups are run
	// against. When created by SQLStore.GraphSession, this is the handle of
	// the read transaction opened for that session.
	db SQLQueries

	// chain is the chain hash the traverser was constructed with
	// (GraphSession passes the store's configured ChainHash).
	chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2779

2780
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2781
func newSQLNodeTraverser(db SQLQueries,
2782
        chain chainhash.Hash) *sqlNodeTraverser {
×
2783

×
2784
        return &sqlNodeTraverser{
×
2785
                db:    db,
×
2786
                chain: chain,
×
2787
        }
×
2788
}
×
2789

2790
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2791
// node.
2792
//
2793
// NOTE: Part of the NodeTraverser interface.
2794
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2795
        cb func(channel *DirectedChannel) error) error {
×
2796

×
2797
        ctx := context.TODO()
×
2798

×
2799
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2800
}
×
2801

2802
// FetchNodeFeatures returns the features of the given node. If the node is
2803
// unknown, assume no additional features are supported.
2804
//
2805
// NOTE: Part of the NodeTraverser interface.
2806
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2807
        *lnwire.FeatureVector, error) {
×
2808

×
2809
        ctx := context.TODO()
×
2810

×
2811
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2812
}
×
2813

2814
// forEachNodeDirectedChannel iterates through all channels of a given
2815
// node, executing the passed callback on the directed edge representing the
2816
// channel and its incoming policy. If the node is not found, no error is
2817
// returned.
2818
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2819
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2820

×
2821
        toNodeCallback := func() route.Vertex {
×
2822
                return nodePub
×
2823
        }
×
2824

2825
        dbID, err := db.GetNodeIDByPubKey(
×
2826
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2827
                        Version: int16(ProtocolV1),
×
2828
                        PubKey:  nodePub[:],
×
2829
                },
×
2830
        )
×
2831
        if errors.Is(err, sql.ErrNoRows) {
×
2832
                return nil
×
2833
        } else if err != nil {
×
2834
                return fmt.Errorf("unable to fetch node: %w", err)
×
2835
        }
×
2836

2837
        rows, err := db.ListChannelsByNodeID(
×
2838
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2839
                        Version: int16(ProtocolV1),
×
2840
                        NodeID1: dbID,
×
2841
                },
×
2842
        )
×
2843
        if err != nil {
×
2844
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2845
        }
×
2846

2847
        // Exit early if there are no channels for this node so we don't
2848
        // do the unnecessary feature fetching.
2849
        if len(rows) == 0 {
×
2850
                return nil
×
2851
        }
×
2852

2853
        features, err := getNodeFeatures(ctx, db, dbID)
×
2854
        if err != nil {
×
2855
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2856
        }
×
2857

2858
        for _, row := range rows {
×
2859
                node1, node2, err := buildNodeVertices(
×
2860
                        row.Node1Pubkey, row.Node2Pubkey,
×
2861
                )
×
2862
                if err != nil {
×
2863
                        return fmt.Errorf("unable to build node vertices: %w",
×
2864
                                err)
×
2865
                }
×
2866

2867
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2868

×
2869
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2870
                if err != nil {
×
2871
                        return err
×
2872
                }
×
2873

2874
                var p1, p2 *models.CachedEdgePolicy
×
2875
                if dbPol1 != nil {
×
2876
                        policy1, err := buildChanPolicy(
×
2877
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2878
                        )
×
2879
                        if err != nil {
×
2880
                                return err
×
2881
                        }
×
2882

2883
                        p1 = models.NewCachedPolicy(policy1)
×
2884
                }
2885
                if dbPol2 != nil {
×
2886
                        policy2, err := buildChanPolicy(
×
2887
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2888
                        )
×
2889
                        if err != nil {
×
2890
                                return err
×
2891
                        }
×
2892

2893
                        p2 = models.NewCachedPolicy(policy2)
×
2894
                }
2895

2896
                // Determine the outgoing and incoming policy for this
2897
                // channel and node combo.
2898
                outPolicy, inPolicy := p1, p2
×
2899
                if p1 != nil && node2 == nodePub {
×
2900
                        outPolicy, inPolicy = p2, p1
×
2901
                } else if p2 != nil && node1 != nodePub {
×
2902
                        outPolicy, inPolicy = p2, p1
×
2903
                }
×
2904

2905
                var cachedInPolicy *models.CachedEdgePolicy
×
2906
                if inPolicy != nil {
×
2907
                        cachedInPolicy = inPolicy
×
2908
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2909
                        cachedInPolicy.ToNodeFeatures = features
×
2910
                }
×
2911

2912
                directedChannel := &DirectedChannel{
×
2913
                        ChannelID:    edge.ChannelID,
×
2914
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2915
                        OtherNode:    edge.NodeKey2Bytes,
×
2916
                        Capacity:     edge.Capacity,
×
2917
                        OutPolicySet: outPolicy != nil,
×
2918
                        InPolicy:     cachedInPolicy,
×
2919
                }
×
2920
                if outPolicy != nil {
×
2921
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2922
                                directedChannel.InboundFee = fee
×
2923
                        })
×
2924
                }
2925

2926
                if nodePub == edge.NodeKey2Bytes {
×
2927
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2928
                }
×
2929

2930
                if err := cb(directedChannel); err != nil {
×
2931
                        return err
×
2932
                }
×
2933
        }
2934

2935
        return nil
×
2936
}
2937

2938
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2939
// and executes the provided callback for each node.
2940
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2941
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2942

×
2943
        lastID := int64(-1)
×
2944

×
2945
        for {
×
2946
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2947
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2948
                                Version: int16(ProtocolV1),
×
2949
                                ID:      lastID,
×
2950
                                Limit:   pageSize,
×
2951
                        },
×
2952
                )
×
2953
                if err != nil {
×
2954
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2955
                }
×
2956

2957
                if len(nodes) == 0 {
×
2958
                        break
×
2959
                }
2960

2961
                for _, node := range nodes {
×
2962
                        var pub route.Vertex
×
2963
                        copy(pub[:], node.PubKey)
×
2964

×
2965
                        if err := cb(node.ID, pub); err != nil {
×
2966
                                return fmt.Errorf("forEachNodeCacheable "+
×
2967
                                        "callback failed for node(id=%d): %w",
×
2968
                                        node.ID, err)
×
2969
                        }
×
2970

2971
                        lastID = node.ID
×
2972
                }
2973
        }
2974

2975
        return nil
×
2976
}
2977

2978
// forEachNodeChannel iterates through all channels of a node, executing
2979
// the passed callback on each. The call-back is provided with the channel's
2980
// edge information, the outgoing policy and the incoming policy for the
2981
// channel and node combo.
2982
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2983
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2984
                *models.ChannelEdgePolicy,
2985
                *models.ChannelEdgePolicy) error) error {
×
2986

×
2987
        // Get all the V1 channels for this node.Add commentMore actions
×
2988
        rows, err := db.ListChannelsByNodeID(
×
2989
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2990
                        Version: int16(ProtocolV1),
×
2991
                        NodeID1: id,
×
2992
                },
×
2993
        )
×
2994
        if err != nil {
×
2995
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2996
        }
×
2997

2998
        // Call the call-back for each channel and its known policies.
2999
        for _, row := range rows {
×
3000
                node1, node2, err := buildNodeVertices(
×
3001
                        row.Node1Pubkey, row.Node2Pubkey,
×
3002
                )
×
3003
                if err != nil {
×
3004
                        return fmt.Errorf("unable to build node vertices: %w",
×
3005
                                err)
×
3006
                }
×
3007

3008
                edge, err := getAndBuildEdgeInfo(
×
3009
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3010
                        node2,
×
3011
                )
×
3012
                if err != nil {
×
3013
                        return fmt.Errorf("unable to build channel info: %w",
×
3014
                                err)
×
3015
                }
×
3016

3017
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3018
                if err != nil {
×
3019
                        return fmt.Errorf("unable to extract channel "+
×
3020
                                "policies: %w", err)
×
3021
                }
×
3022

3023
                p1, p2, err := getAndBuildChanPolicies(
×
3024
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3025
                )
×
3026
                if err != nil {
×
3027
                        return fmt.Errorf("unable to build channel "+
×
3028
                                "policies: %w", err)
×
3029
                }
×
3030

3031
                // Determine the outgoing and incoming policy for this
3032
                // channel and node combo.
3033
                p1ToNode := row.Channel.NodeID2
×
3034
                p2ToNode := row.Channel.NodeID1
×
3035
                outPolicy, inPolicy := p1, p2
×
3036
                if (p1 != nil && p1ToNode == id) ||
×
3037
                        (p2 != nil && p2ToNode != id) {
×
3038

×
3039
                        outPolicy, inPolicy = p2, p1
×
3040
                }
×
3041

3042
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3043
                        return err
×
3044
                }
×
3045
        }
3046

3047
        return nil
×
3048
}
3049

3050
// updateChanEdgePolicy upserts the channel policy info we have stored for
3051
// a channel we already know of.
3052
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3053
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3054
        error) {
×
3055

×
3056
        var (
×
3057
                node1Pub, node2Pub route.Vertex
×
3058
                isNode1            bool
×
3059
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3060
        )
×
3061

×
3062
        // Check that this edge policy refers to a channel that we already
×
3063
        // know of. We do this explicitly so that we can return the appropriate
×
3064
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3065
        // abort the transaction which would abort the entire batch.
×
3066
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3067
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3068
                        Scid:    chanIDB[:],
×
3069
                        Version: int16(ProtocolV1),
×
3070
                },
×
3071
        )
×
3072
        if errors.Is(err, sql.ErrNoRows) {
×
3073
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3074
        } else if err != nil {
×
3075
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3076
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3077
        }
×
3078

3079
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3080
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3081

×
3082
        // Figure out which node this edge is from.
×
3083
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3084
        nodeID := dbChan.NodeID1
×
3085
        if !isNode1 {
×
3086
                nodeID = dbChan.NodeID2
×
3087
        }
×
3088

3089
        var (
×
3090
                inboundBase sql.NullInt64
×
3091
                inboundRate sql.NullInt64
×
3092
        )
×
3093
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3094
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3095
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3096
        })
×
3097

3098
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3099
                Version:     int16(ProtocolV1),
×
3100
                ChannelID:   dbChan.ID,
×
3101
                NodeID:      nodeID,
×
3102
                Timelock:    int32(edge.TimeLockDelta),
×
3103
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3104
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3105
                MinHtlcMsat: int64(edge.MinHTLC),
×
3106
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3107
                Disabled: sql.NullBool{
×
3108
                        Valid: true,
×
3109
                        Bool:  edge.IsDisabled(),
×
3110
                },
×
3111
                MaxHtlcMsat: sql.NullInt64{
×
3112
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3113
                        Int64: int64(edge.MaxHTLC),
×
3114
                },
×
3115
                InboundBaseFeeMsat:      inboundBase,
×
3116
                InboundFeeRateMilliMsat: inboundRate,
×
3117
                Signature:               edge.SigBytes,
×
3118
        })
×
3119
        if err != nil {
×
3120
                return node1Pub, node2Pub, isNode1,
×
3121
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3122
        }
×
3123

3124
        // Convert the flat extra opaque data into a map of TLV types to
3125
        // values.
3126
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3127
        if err != nil {
×
3128
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3129
                        "marshal extra opaque data: %w", err)
×
3130
        }
×
3131

3132
        // Update the channel policy's extra signed fields.
3133
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3134
        if err != nil {
×
3135
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3136
                        "policy extra TLVs: %w", err)
×
3137
        }
×
3138

3139
        return node1Pub, node2Pub, isNode1, nil
×
3140
}
3141

3142
// getNodeByPubKey attempts to look up a target node by its public key.
3143
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3144
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3145

×
3146
        dbNode, err := db.GetNodeByPubKey(
×
3147
                ctx, sqlc.GetNodeByPubKeyParams{
×
3148
                        Version: int16(ProtocolV1),
×
3149
                        PubKey:  pubKey[:],
×
3150
                },
×
3151
        )
×
3152
        if errors.Is(err, sql.ErrNoRows) {
×
3153
                return 0, nil, ErrGraphNodeNotFound
×
3154
        } else if err != nil {
×
3155
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3156
        }
×
3157

3158
        node, err := buildNode(ctx, db, &dbNode)
×
3159
        if err != nil {
×
3160
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3161
        }
×
3162

3163
        return dbNode.ID, node, nil
×
3164
}
3165

3166
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3167
// provided database channel row and the public keys of the two nodes
3168
// involved in the channel.
3169
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3170
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3171

×
3172
        return &models.CachedEdgeInfo{
×
3173
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3174
                NodeKey1Bytes: node1Pub,
×
3175
                NodeKey2Bytes: node2Pub,
×
3176
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3177
        }
×
3178
}
×
3179

3180
// buildNode constructs a LightningNode instance from the given database node
3181
// record. The node's features, addresses and extra signed fields are also
3182
// fetched from the database and set on the node.
3183
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3184
        *models.LightningNode, error) {
×
3185

×
3186
        if dbNode.Version != int16(ProtocolV1) {
×
3187
                return nil, fmt.Errorf("unsupported node version: %d",
×
3188
                        dbNode.Version)
×
3189
        }
×
3190

3191
        var pub [33]byte
×
3192
        copy(pub[:], dbNode.PubKey)
×
3193

×
3194
        node := &models.LightningNode{
×
3195
                PubKeyBytes: pub,
×
3196
                Features:    lnwire.EmptyFeatureVector(),
×
3197
                LastUpdate:  time.Unix(0, 0),
×
3198
        }
×
3199

×
3200
        if len(dbNode.Signature) == 0 {
×
3201
                return node, nil
×
3202
        }
×
3203

3204
        node.HaveNodeAnnouncement = true
×
3205
        node.AuthSigBytes = dbNode.Signature
×
3206
        node.Alias = dbNode.Alias.String
×
3207
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3208

×
3209
        var err error
×
3210
        if dbNode.Color.Valid {
×
3211
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3212
                if err != nil {
×
3213
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3214
                                err)
×
3215
                }
×
3216
        }
3217

3218
        // Fetch the node's features.
3219
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3220
        if err != nil {
×
3221
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3222
                        "features: %w", dbNode.ID, err)
×
3223
        }
×
3224

3225
        // Fetch the node's addresses.
3226
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3227
        if err != nil {
×
3228
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3229
                        "addresses: %w", dbNode.ID, err)
×
3230
        }
×
3231

3232
        // Fetch the node's extra signed fields.
3233
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3234
        if err != nil {
×
3235
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3236
                        "extra signed fields: %w", dbNode.ID, err)
×
3237
        }
×
3238

3239
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3240
        if err != nil {
×
3241
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3242
                        "fields: %w", err)
×
3243
        }
×
3244

3245
        if len(recs) != 0 {
×
3246
                node.ExtraOpaqueData = recs
×
3247
        }
×
3248

3249
        return node, nil
×
3250
}
3251

3252
// getNodeFeatures fetches the feature bits and constructs the feature vector
3253
// for a node with the given DB ID.
3254
func getNodeFeatures(ctx context.Context, db SQLQueries,
3255
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3256

×
3257
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3258
        if err != nil {
×
3259
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3260
                        nodeID, err)
×
3261
        }
×
3262

3263
        features := lnwire.EmptyFeatureVector()
×
3264
        for _, feature := range rows {
×
3265
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3266
        }
×
3267

3268
        return features, nil
×
3269
}
3270

3271
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3272
// given DB ID.
3273
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3274
        nodeID int64) (map[uint64][]byte, error) {
×
3275

×
3276
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3277
        if err != nil {
×
3278
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3279
                        "signed fields: %w", nodeID, err)
×
3280
        }
×
3281

3282
        extraFields := make(map[uint64][]byte)
×
3283
        for _, field := range fields {
×
3284
                extraFields[uint64(field.Type)] = field.Value
×
3285
        }
×
3286

3287
        return extraFields, nil
×
3288
}
3289

3290
// upsertNode upserts the node record into the database. If the node already
3291
// exists, then the node's information is updated. If the node doesn't exist,
3292
// then a new node is created. The node's features, addresses and extra TLV
3293
// types are also updated. The node's DB ID is returned.
3294
func upsertNode(ctx context.Context, db SQLQueries,
3295
        node *models.LightningNode) (int64, error) {
×
3296

×
3297
        params := sqlc.UpsertNodeParams{
×
3298
                Version: int16(ProtocolV1),
×
3299
                PubKey:  node.PubKeyBytes[:],
×
3300
        }
×
3301

×
3302
        if node.HaveNodeAnnouncement {
×
3303
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3304
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3305
                params.Alias = sqldb.SQLStr(node.Alias)
×
3306
                params.Signature = node.AuthSigBytes
×
3307
        }
×
3308

3309
        nodeID, err := db.UpsertNode(ctx, params)
×
3310
        if err != nil {
×
3311
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3312
                        err)
×
3313
        }
×
3314

3315
        // We can exit here if we don't have the announcement yet.
3316
        if !node.HaveNodeAnnouncement {
×
3317
                return nodeID, nil
×
3318
        }
×
3319

3320
        // Update the node's features.
3321
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3322
        if err != nil {
×
3323
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3324
        }
×
3325

3326
        // Update the node's addresses.
3327
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3328
        if err != nil {
×
3329
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3330
        }
×
3331

3332
        // Convert the flat extra opaque data into a map of TLV types to
3333
        // values.
3334
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3335
        if err != nil {
×
3336
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3337
                        err)
×
3338
        }
×
3339

3340
        // Update the node's extra signed fields.
3341
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3342
        if err != nil {
×
3343
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3344
        }
×
3345

3346
        return nodeID, nil
×
3347
}
3348

3349
// upsertNodeFeatures updates the node's features node_features table. This
3350
// includes deleting any feature bits no longer present and inserting any new
3351
// feature bits. If the feature bit does not yet exist in the features table,
3352
// then an entry is created in that table first.
3353
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3354
        features *lnwire.FeatureVector) error {
×
3355

×
3356
        // Get any existing features for the node.
×
3357
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3358
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3359
                return err
×
3360
        }
×
3361

3362
        // Copy the nodes latest set of feature bits.
3363
        newFeatures := make(map[int32]struct{})
×
3364
        if features != nil {
×
3365
                for feature := range features.Features() {
×
3366
                        newFeatures[int32(feature)] = struct{}{}
×
3367
                }
×
3368
        }
3369

3370
        // For any current feature that already exists in the DB, remove it from
3371
        // the in-memory map. For any existing feature that does not exist in
3372
        // the in-memory map, delete it from the database.
3373
        for _, feature := range existingFeatures {
×
3374
                // The feature is still present, so there are no updates to be
×
3375
                // made.
×
3376
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3377
                        delete(newFeatures, feature.FeatureBit)
×
3378
                        continue
×
3379
                }
3380

3381
                // The feature is no longer present, so we remove it from the
3382
                // database.
3383
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3384
                        NodeID:     nodeID,
×
3385
                        FeatureBit: feature.FeatureBit,
×
3386
                })
×
3387
                if err != nil {
×
3388
                        return fmt.Errorf("unable to delete node(%d) "+
×
3389
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3390
                                err)
×
3391
                }
×
3392
        }
3393

3394
        // Any remaining entries in newFeatures are new features that need to be
3395
        // added to the database for the first time.
3396
        for feature := range newFeatures {
×
3397
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3398
                        NodeID:     nodeID,
×
3399
                        FeatureBit: feature,
×
3400
                })
×
3401
                if err != nil {
×
3402
                        return fmt.Errorf("unable to insert node(%d) "+
×
3403
                                "feature(%v): %w", nodeID, feature, err)
×
3404
                }
×
3405
        }
3406

3407
        return nil
×
3408
}
3409

3410
// fetchNodeFeatures fetches the features for a node with the given public key.
3411
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3412
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3413

×
3414
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3415
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3416
                        PubKey:  nodePub[:],
×
3417
                        Version: int16(ProtocolV1),
×
3418
                },
×
3419
        )
×
3420
        if err != nil {
×
3421
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3422
                        nodePub, err)
×
3423
        }
×
3424

3425
        features := lnwire.EmptyFeatureVector()
×
3426
        for _, bit := range rows {
×
3427
                features.Set(lnwire.FeatureBit(bit))
×
3428
        }
×
3429

3430
        return features, nil
×
3431
}
3432

3433
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type recorded next to each address determines how the stored
// string is serialized and how it is parsed back into a net.Addr.
//
// NOTE: these values are written to the database, so existing values must
// never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an uninterpreted address payload, stored in
	// its hex-encoded string form.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3445

3446
// upsertNodeAddresses updates the node's addresses in the database. This
3447
// includes deleting any existing addresses and inserting the new set of
3448
// addresses. The deletion is necessary since the ordering of the addresses may
3449
// change, and we need to ensure that the database reflects the latest set of
3450
// addresses so that at the time of reconstructing the node announcement, the
3451
// order is preserved and the signature over the message remains valid.
3452
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3453
        addresses []net.Addr) error {
×
3454

×
3455
        // Delete any existing addresses for the node. This is required since
×
3456
        // even if the new set of addresses is the same, the ordering may have
×
3457
        // changed for a given address type.
×
3458
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3459
        if err != nil {
×
3460
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3461
                        nodeID, err)
×
3462
        }
×
3463

3464
        // Copy the nodes latest set of addresses.
3465
        newAddresses := map[dbAddressType][]string{
×
3466
                addressTypeIPv4:   {},
×
3467
                addressTypeIPv6:   {},
×
3468
                addressTypeTorV2:  {},
×
3469
                addressTypeTorV3:  {},
×
3470
                addressTypeOpaque: {},
×
3471
        }
×
3472
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3473
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3474
        }
×
3475

3476
        for _, address := range addresses {
×
3477
                switch addr := address.(type) {
×
3478
                case *net.TCPAddr:
×
3479
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3480
                                addAddr(addressTypeIPv4, addr)
×
3481
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3482
                                addAddr(addressTypeIPv6, addr)
×
3483
                        } else {
×
3484
                                return fmt.Errorf("unhandled IP address: %v",
×
3485
                                        addr)
×
3486
                        }
×
3487

3488
                case *tor.OnionAddr:
×
3489
                        switch len(addr.OnionService) {
×
3490
                        case tor.V2Len:
×
3491
                                addAddr(addressTypeTorV2, addr)
×
3492
                        case tor.V3Len:
×
3493
                                addAddr(addressTypeTorV3, addr)
×
3494
                        default:
×
3495
                                return fmt.Errorf("invalid length for a tor " +
×
3496
                                        "address")
×
3497
                        }
3498

3499
                case *lnwire.OpaqueAddrs:
×
3500
                        addAddr(addressTypeOpaque, addr)
×
3501

3502
                default:
×
3503
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3504
                }
3505
        }
3506

3507
        // Any remaining entries in newAddresses are new addresses that need to
3508
        // be added to the database for the first time.
3509
        for addrType, addrList := range newAddresses {
×
3510
                for position, addr := range addrList {
×
3511
                        err := db.InsertNodeAddress(
×
3512
                                ctx, sqlc.InsertNodeAddressParams{
×
3513
                                        NodeID:   nodeID,
×
3514
                                        Type:     int16(addrType),
×
3515
                                        Address:  addr,
×
3516
                                        Position: int32(position),
×
3517
                                },
×
3518
                        )
×
3519
                        if err != nil {
×
3520
                                return fmt.Errorf("unable to insert "+
×
3521
                                        "node(%d) address(%v): %w", nodeID,
×
3522
                                        addr, err)
×
3523
                        }
×
3524
                }
3525
        }
3526

3527
        return nil
×
3528
}
3529

3530
// getNodeAddresses fetches the addresses for a node with the given public key.
3531
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3532
        []net.Addr, error) {
×
3533

×
3534
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3535
        // are returned in the same order as they were inserted.
×
3536
        rows, err := db.GetNodeAddressesByPubKey(
×
3537
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3538
                        Version: int16(ProtocolV1),
×
3539
                        PubKey:  nodePub,
×
3540
                },
×
3541
        )
×
3542
        if err != nil {
×
3543
                return false, nil, err
×
3544
        }
×
3545

3546
        // GetNodeAddressesByPubKey uses a left join so there should always be
3547
        // at least one row returned if the node exists even if it has no
3548
        // addresses.
3549
        if len(rows) == 0 {
×
3550
                return false, nil, nil
×
3551
        }
×
3552

3553
        addresses := make([]net.Addr, 0, len(rows))
×
3554
        for _, addr := range rows {
×
3555
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3556
                        continue
×
3557
                }
3558

3559
                address := addr.Address.String
×
3560

×
3561
                switch dbAddressType(addr.Type.Int16) {
×
3562
                case addressTypeIPv4:
×
3563
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3564
                        if err != nil {
×
3565
                                return false, nil, nil
×
3566
                        }
×
3567
                        tcp.IP = tcp.IP.To4()
×
3568

×
3569
                        addresses = append(addresses, tcp)
×
3570

3571
                case addressTypeIPv6:
×
3572
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3573
                        if err != nil {
×
3574
                                return false, nil, nil
×
3575
                        }
×
3576
                        addresses = append(addresses, tcp)
×
3577

3578
                case addressTypeTorV3, addressTypeTorV2:
×
3579
                        service, portStr, err := net.SplitHostPort(address)
×
3580
                        if err != nil {
×
3581
                                return false, nil, fmt.Errorf("unable to "+
×
3582
                                        "split tor v3 address: %v",
×
3583
                                        addr.Address)
×
3584
                        }
×
3585

3586
                        port, err := strconv.Atoi(portStr)
×
3587
                        if err != nil {
×
3588
                                return false, nil, err
×
3589
                        }
×
3590

3591
                        addresses = append(addresses, &tor.OnionAddr{
×
3592
                                OnionService: service,
×
3593
                                Port:         port,
×
3594
                        })
×
3595

3596
                case addressTypeOpaque:
×
3597
                        opaque, err := hex.DecodeString(address)
×
3598
                        if err != nil {
×
3599
                                return false, nil, fmt.Errorf("unable to "+
×
3600
                                        "decode opaque address: %v", addr)
×
3601
                        }
×
3602

3603
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3604
                                Payload: opaque,
×
3605
                        })
×
3606

3607
                default:
×
3608
                        return false, nil, fmt.Errorf("unknown address "+
×
3609
                                "type: %v", addr.Type)
×
3610
                }
3611
        }
3612

3613
        // If we have no addresses, then we'll return nil instead of an
3614
        // empty slice.
NEW
3615
        if len(addresses) == 0 {
×
NEW
3616
                addresses = nil
×
NEW
3617
        }
×
3618

UNCOV
3619
        return true, addresses, nil
×
3620
}
3621

3622
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3623
// database. This includes updating any existing types, inserting any new types,
3624
// and deleting any types that are no longer present.
3625
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3626
        nodeID int64, extraFields map[uint64][]byte) error {
×
3627

×
3628
        // Get any existing extra signed fields for the node.
×
3629
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3630
        if err != nil {
×
3631
                return err
×
3632
        }
×
3633

3634
        // Make a lookup map of the existing field types so that we can use it
3635
        // to keep track of any fields we should delete.
3636
        m := make(map[uint64]bool)
×
3637
        for _, field := range existingFields {
×
3638
                m[uint64(field.Type)] = true
×
3639
        }
×
3640

3641
        // For all the new fields, we'll upsert them and remove them from the
3642
        // map of existing fields.
3643
        for tlvType, value := range extraFields {
×
3644
                err = db.UpsertNodeExtraType(
×
3645
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3646
                                NodeID: nodeID,
×
3647
                                Type:   int64(tlvType),
×
3648
                                Value:  value,
×
3649
                        },
×
3650
                )
×
3651
                if err != nil {
×
3652
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3653
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3654
                }
×
3655

3656
                // Remove the field from the map of existing fields if it was
3657
                // present.
3658
                delete(m, tlvType)
×
3659
        }
3660

3661
        // For all the fields that are left in the map of existing fields, we'll
3662
        // delete them as they are no longer present in the new set of fields.
3663
        for tlvType := range m {
×
3664
                err = db.DeleteExtraNodeType(
×
3665
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3666
                                NodeID: nodeID,
×
3667
                                Type:   int64(tlvType),
×
3668
                        },
×
3669
                )
×
3670
                if err != nil {
×
3671
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3672
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3673
                }
×
3674
        }
3675

3676
        return nil
×
3677
}
3678

3679
// srcNodeInfo holds the information about the source node of the graph.
3680
type srcNodeInfo struct {
3681
        // id is the DB level ID of the source node entry in the "nodes" table.
3682
        id int64
3683

3684
        // pub is the public key of the source node.
3685
        pub route.Vertex
3686
}
3687

3688
// getSourceNode returns the DB node ID and pub key of the source node for the
3689
// specified protocol version.
3690
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3691
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3692

×
3693
        s.srcNodeMu.Lock()
×
3694
        defer s.srcNodeMu.Unlock()
×
3695

×
3696
        // If we already have the source node ID and pub key cached, then
×
3697
        // return them.
×
3698
        if info, ok := s.srcNodes[version]; ok {
×
3699
                return info.id, info.pub, nil
×
3700
        }
×
3701

3702
        var pubKey route.Vertex
×
3703

×
3704
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3705
        if err != nil {
×
3706
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3707
                        err)
×
3708
        }
×
3709

3710
        if len(nodes) == 0 {
×
3711
                return 0, pubKey, ErrSourceNodeNotSet
×
3712
        } else if len(nodes) > 1 {
×
3713
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3714
                        "protocol %s found", version)
×
3715
        }
×
3716

3717
        copy(pubKey[:], nodes[0].PubKey)
×
3718

×
3719
        s.srcNodes[version] = &srcNodeInfo{
×
3720
                id:  nodes[0].NodeID,
×
3721
                pub: pubKey,
×
3722
        }
×
3723

×
3724
        return nodes[0].NodeID, pubKey, nil
×
3725
}
3726

3727
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3728
// This then produces a map from TLV type to value. If the input is not a
3729
// valid TLV stream, then an error is returned.
3730
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3731
        r := bytes.NewReader(data)
×
3732

×
3733
        tlvStream, err := tlv.NewStream()
×
3734
        if err != nil {
×
3735
                return nil, err
×
3736
        }
×
3737

3738
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3739
        // pass it into the P2P decoding variant.
3740
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3741
        if err != nil {
×
3742
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3743
        }
×
3744
        if len(parsedTypes) == 0 {
×
3745
                return nil, nil
×
3746
        }
×
3747

3748
        records := make(map[uint64][]byte)
×
3749
        for k, v := range parsedTypes {
×
3750
                records[uint64(k)] = v
×
3751
        }
×
3752

3753
        return records, nil
×
3754
}
3755

3756
// insertChannel inserts a new channel record into the database.
3757
func insertChannel(ctx context.Context, db SQLQueries,
3758
        edge *models.ChannelEdgeInfo) error {
×
3759

×
3760
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3761

×
3762
        // Make sure that the channel doesn't already exist. We do this
×
3763
        // explicitly instead of relying on catching a unique constraint error
×
3764
        // because relying on SQL to throw that error would abort the entire
×
3765
        // batch of transactions.
×
3766
        _, err := db.GetChannelBySCID(
×
3767
                ctx, sqlc.GetChannelBySCIDParams{
×
3768
                        Scid:    chanIDB[:],
×
3769
                        Version: int16(ProtocolV1),
×
3770
                },
×
3771
        )
×
3772
        if err == nil {
×
3773
                return ErrEdgeAlreadyExist
×
3774
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3775
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3776
        }
×
3777

3778
        // Make sure that at least a "shell" entry for each node is present in
3779
        // the nodes table.
3780
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3781
        if err != nil {
×
3782
                return fmt.Errorf("unable to create shell node: %w", err)
×
3783
        }
×
3784

3785
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3786
        if err != nil {
×
3787
                return fmt.Errorf("unable to create shell node: %w", err)
×
3788
        }
×
3789

3790
        var capacity sql.NullInt64
×
3791
        if edge.Capacity != 0 {
×
3792
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3793
        }
×
3794

3795
        createParams := sqlc.CreateChannelParams{
×
3796
                Version:     int16(ProtocolV1),
×
3797
                Scid:        chanIDB[:],
×
3798
                NodeID1:     node1DBID,
×
3799
                NodeID2:     node2DBID,
×
3800
                Outpoint:    edge.ChannelPoint.String(),
×
3801
                Capacity:    capacity,
×
3802
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3803
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3804
        }
×
3805

×
3806
        if edge.AuthProof != nil {
×
3807
                proof := edge.AuthProof
×
3808

×
3809
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3810
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3811
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3812
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3813
        }
×
3814

3815
        // Insert the new channel record.
3816
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3817
        if err != nil {
×
3818
                return err
×
3819
        }
×
3820

3821
        // Insert any channel features.
3822
        if len(edge.Features) != 0 {
×
3823
                chanFeatures := lnwire.NewRawFeatureVector()
×
3824
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3825
                if err != nil {
×
3826
                        return err
×
3827
                }
×
3828

3829
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3830
                for feature := range fv.Features() {
×
3831
                        err = db.InsertChannelFeature(
×
3832
                                ctx, sqlc.InsertChannelFeatureParams{
×
3833
                                        ChannelID:  dbChanID,
×
3834
                                        FeatureBit: int32(feature),
×
3835
                                },
×
3836
                        )
×
3837
                        if err != nil {
×
3838
                                return fmt.Errorf("unable to insert "+
×
3839
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3840
                                        feature, err)
×
3841
                        }
×
3842
                }
3843
        }
3844

3845
        // Finally, insert any extra TLV fields in the channel announcement.
3846
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3847
        if err != nil {
×
3848
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3849
                        err)
×
3850
        }
×
3851

3852
        for tlvType, value := range extra {
×
3853
                err := db.CreateChannelExtraType(
×
3854
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3855
                                ChannelID: dbChanID,
×
3856
                                Type:      int64(tlvType),
×
3857
                                Value:     value,
×
3858
                        },
×
3859
                )
×
3860
                if err != nil {
×
3861
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3862
                                "signed field(%v): %w", edge.ChannelID,
×
3863
                                tlvType, err)
×
3864
                }
×
3865
        }
3866

3867
        return nil
×
3868
}
3869

3870
// maybeCreateShellNode checks if a shell node entry exists for the
3871
// given public key. If it does not exist, then a new shell node entry is
3872
// created. The ID of the node is returned. A shell node only has a protocol
3873
// version and public key persisted.
3874
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3875
        pubKey route.Vertex) (int64, error) {
×
3876

×
3877
        dbNode, err := db.GetNodeByPubKey(
×
3878
                ctx, sqlc.GetNodeByPubKeyParams{
×
3879
                        PubKey:  pubKey[:],
×
3880
                        Version: int16(ProtocolV1),
×
3881
                },
×
3882
        )
×
3883
        // The node exists. Return the ID.
×
3884
        if err == nil {
×
3885
                return dbNode.ID, nil
×
3886
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3887
                return 0, err
×
3888
        }
×
3889

3890
        // Otherwise, the node does not exist, so we create a shell entry for
3891
        // it.
3892
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3893
                Version: int16(ProtocolV1),
×
3894
                PubKey:  pubKey[:],
×
3895
        })
×
3896
        if err != nil {
×
3897
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3898
        }
×
3899

3900
        return id, nil
×
3901
}
3902

3903
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3904
// the database. This includes deleting any existing types and then inserting
3905
// the new types.
3906
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3907
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3908

×
3909
        // Delete all existing extra signed fields for the channel policy.
×
3910
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3911
        if err != nil {
×
3912
                return fmt.Errorf("unable to delete "+
×
3913
                        "existing policy extra signed fields for policy %d: %w",
×
3914
                        chanPolicyID, err)
×
3915
        }
×
3916

3917
        // Insert all new extra signed fields for the channel policy.
3918
        for tlvType, value := range extraFields {
×
3919
                err = db.InsertChanPolicyExtraType(
×
3920
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3921
                                ChannelPolicyID: chanPolicyID,
×
3922
                                Type:            int64(tlvType),
×
3923
                                Value:           value,
×
3924
                        },
×
3925
                )
×
3926
                if err != nil {
×
3927
                        return fmt.Errorf("unable to insert "+
×
3928
                                "channel_policy(%d) extra signed field(%v): %w",
×
3929
                                chanPolicyID, tlvType, err)
×
3930
                }
×
3931
        }
3932

3933
        return nil
×
3934
}
3935

3936
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3937
// provided dbChanRow and also fetches any other required information
3938
// to construct the edge info.
3939
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3940
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3941
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3942

×
3943
        if dbChan.Version != int16(ProtocolV1) {
×
3944
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3945
                        dbChan.Version)
×
3946
        }
×
3947

3948
        fv, extras, err := getChanFeaturesAndExtras(
×
3949
                ctx, db, dbChanID,
×
3950
        )
×
3951
        if err != nil {
×
3952
                return nil, err
×
3953
        }
×
3954

3955
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3956
        if err != nil {
×
3957
                return nil, err
×
3958
        }
×
3959

3960
        var featureBuf bytes.Buffer
×
3961
        if err := fv.Encode(&featureBuf); err != nil {
×
3962
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3963
        }
×
3964

3965
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3966
        if err != nil {
×
3967
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3968
                        "fields: %w", err)
×
3969
        }
×
3970
        if recs == nil {
×
3971
                recs = make([]byte, 0)
×
3972
        }
×
3973

3974
        var btcKey1, btcKey2 route.Vertex
×
3975
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3976
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3977

×
3978
        channel := &models.ChannelEdgeInfo{
×
3979
                ChainHash:        chain,
×
3980
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3981
                NodeKey1Bytes:    node1,
×
3982
                NodeKey2Bytes:    node2,
×
3983
                BitcoinKey1Bytes: btcKey1,
×
3984
                BitcoinKey2Bytes: btcKey2,
×
3985
                ChannelPoint:     *op,
×
3986
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3987
                Features:         featureBuf.Bytes(),
×
3988
                ExtraOpaqueData:  recs,
×
3989
        }
×
3990

×
3991
        // We always set all the signatures at the same time, so we can
×
3992
        // safely check if one signature is present to determine if we have the
×
3993
        // rest of the signatures for the auth proof.
×
3994
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3995
                channel.AuthProof = &models.ChannelAuthProof{
×
3996
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
3997
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
3998
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
3999
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4000
                }
×
4001
        }
×
4002

4003
        return channel, nil
×
4004
}
4005

4006
// buildNodeVertices is a helper that converts raw node public keys
4007
// into route.Vertex instances.
4008
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4009
        route.Vertex, error) {
×
4010

×
4011
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4012
        if err != nil {
×
4013
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4014
                        "create vertex from node1 pubkey: %w", err)
×
4015
        }
×
4016

4017
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4018
        if err != nil {
×
4019
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4020
                        "create vertex from node2 pubkey: %w", err)
×
4021
        }
×
4022

4023
        return node1Vertex, node2Vertex, nil
×
4024
}
4025

4026
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4027
// for a channel with the given ID.
4028
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4029
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4030

×
4031
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4032
        if err != nil {
×
4033
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4034
                        "features and extras: %w", err)
×
4035
        }
×
4036

4037
        var (
×
4038
                fv     = lnwire.EmptyFeatureVector()
×
4039
                extras = make(map[uint64][]byte)
×
4040
        )
×
4041
        for _, row := range rows {
×
4042
                if row.IsFeature {
×
4043
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4044

×
4045
                        continue
×
4046
                }
4047

4048
                tlvType, ok := row.ExtraKey.(int64)
×
4049
                if !ok {
×
4050
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4051
                                "TLV type: %T", row.ExtraKey)
×
4052
                }
×
4053

4054
                valueBytes, ok := row.Value.([]byte)
×
4055
                if !ok {
×
4056
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4057
                                "Value: %T", row.Value)
×
4058
                }
×
4059

4060
                extras[uint64(tlvType)] = valueBytes
×
4061
        }
4062

4063
        return fv, extras, nil
×
4064
}
4065

4066
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4067
// all the extra info required to build the complete models.ChannelEdgePolicy
4068
// types. It returns two policies, which may be nil if the provided
4069
// sqlc.ChannelPolicy records are nil.
4070
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4071
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4072
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4073
        *models.ChannelEdgePolicy, error) {
×
4074

×
4075
        if dbPol1 == nil && dbPol2 == nil {
×
4076
                return nil, nil, nil
×
4077
        }
×
4078

4079
        var (
×
4080
                policy1ID int64
×
4081
                policy2ID int64
×
4082
        )
×
4083
        if dbPol1 != nil {
×
4084
                policy1ID = dbPol1.ID
×
4085
        }
×
4086
        if dbPol2 != nil {
×
4087
                policy2ID = dbPol2.ID
×
4088
        }
×
4089
        rows, err := db.GetChannelPolicyExtraTypes(
×
4090
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4091
                        ID:   policy1ID,
×
4092
                        ID_2: policy2ID,
×
4093
                },
×
4094
        )
×
4095
        if err != nil {
×
4096
                return nil, nil, err
×
4097
        }
×
4098

4099
        var (
×
4100
                dbPol1Extras = make(map[uint64][]byte)
×
4101
                dbPol2Extras = make(map[uint64][]byte)
×
4102
        )
×
4103
        for _, row := range rows {
×
4104
                switch row.PolicyID {
×
4105
                case policy1ID:
×
4106
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4107
                case policy2ID:
×
4108
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4109
                default:
×
4110
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4111
                                "in row: %v", row.PolicyID, row)
×
4112
                }
4113
        }
4114

4115
        var pol1, pol2 *models.ChannelEdgePolicy
×
4116
        if dbPol1 != nil {
×
4117
                pol1, err = buildChanPolicy(
×
4118
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
4119
                )
×
4120
                if err != nil {
×
4121
                        return nil, nil, err
×
4122
                }
×
4123
        }
4124
        if dbPol2 != nil {
×
4125
                pol2, err = buildChanPolicy(
×
4126
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
4127
                )
×
4128
                if err != nil {
×
4129
                        return nil, nil, err
×
4130
                }
×
4131
        }
4132

4133
        return pol1, pol2, nil
×
4134
}
4135

4136
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4137
// provided sqlc.ChannelPolicy and other required information.
4138
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4139
        extras map[uint64][]byte, toNode route.Vertex,
4140
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
4141

×
4142
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4143
        if err != nil {
×
4144
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4145
                        "fields: %w", err)
×
4146
        }
×
4147

4148
        var msgFlags lnwire.ChanUpdateMsgFlags
×
4149
        if dbPolicy.MaxHtlcMsat.Valid {
×
4150
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
4151
        }
×
4152

4153
        var chanFlags lnwire.ChanUpdateChanFlags
×
4154
        if !isNode1 {
×
4155
                chanFlags |= lnwire.ChanUpdateDirection
×
4156
        }
×
4157
        if dbPolicy.Disabled.Bool {
×
4158
                chanFlags |= lnwire.ChanUpdateDisabled
×
4159
        }
×
4160

4161
        var inboundFee fn.Option[lnwire.Fee]
×
4162
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4163
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4164

×
4165
                inboundFee = fn.Some(lnwire.Fee{
×
4166
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4167
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4168
                })
×
4169
        }
×
4170

4171
        return &models.ChannelEdgePolicy{
×
4172
                SigBytes:  dbPolicy.Signature,
×
4173
                ChannelID: channelID,
×
4174
                LastUpdate: time.Unix(
×
4175
                        dbPolicy.LastUpdate.Int64, 0,
×
4176
                ),
×
4177
                MessageFlags:  msgFlags,
×
4178
                ChannelFlags:  chanFlags,
×
4179
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4180
                MinHTLC: lnwire.MilliSatoshi(
×
4181
                        dbPolicy.MinHtlcMsat,
×
4182
                ),
×
4183
                MaxHTLC: lnwire.MilliSatoshi(
×
4184
                        dbPolicy.MaxHtlcMsat.Int64,
×
4185
                ),
×
4186
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4187
                        dbPolicy.BaseFeeMsat,
×
4188
                ),
×
4189
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4190
                ToNode:                    toNode,
×
4191
                InboundFee:                inboundFee,
×
4192
                ExtraOpaqueData:           recs,
×
4193
        }, nil
×
4194
}
4195

4196
// buildNodes builds the models.LightningNode instances for the
4197
// given row which is expected to be a sqlc type that contains node information.
4198
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4199
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4200
        error) {
×
4201

×
4202
        node1, err := buildNode(ctx, db, &dbNode1)
×
4203
        if err != nil {
×
4204
                return nil, nil, err
×
4205
        }
×
4206

4207
        node2, err := buildNode(ctx, db, &dbNode2)
×
4208
        if err != nil {
×
4209
                return nil, nil, err
×
4210
        }
×
4211

4212
        return node1, node2, nil
×
4213
}
4214

4215
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4216
// row which is expected to be a sqlc type that contains channel policy
4217
// information. It returns two policies, which may be nil if the policy
4218
// information is not present in the row.
4219
//
4220
//nolint:ll,dupl,funlen
4221
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4222
        error) {
×
4223

×
4224
        var policy1, policy2 *sqlc.ChannelPolicy
×
4225
        switch r := row.(type) {
×
4226
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4227
                if r.Policy1ID.Valid {
×
4228
                        policy1 = &sqlc.ChannelPolicy{
×
4229
                                ID:                      r.Policy1ID.Int64,
×
4230
                                Version:                 r.Policy1Version.Int16,
×
4231
                                ChannelID:               r.Channel.ID,
×
4232
                                NodeID:                  r.Policy1NodeID.Int64,
×
4233
                                Timelock:                r.Policy1Timelock.Int32,
×
4234
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4235
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4236
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4237
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4238
                                LastUpdate:              r.Policy1LastUpdate,
×
4239
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4240
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4241
                                Disabled:                r.Policy1Disabled,
×
4242
                                Signature:               r.Policy1Signature,
×
4243
                        }
×
4244
                }
×
4245
                if r.Policy2ID.Valid {
×
4246
                        policy2 = &sqlc.ChannelPolicy{
×
4247
                                ID:                      r.Policy2ID.Int64,
×
4248
                                Version:                 r.Policy2Version.Int16,
×
4249
                                ChannelID:               r.Channel.ID,
×
4250
                                NodeID:                  r.Policy2NodeID.Int64,
×
4251
                                Timelock:                r.Policy2Timelock.Int32,
×
4252
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4253
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4254
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4255
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4256
                                LastUpdate:              r.Policy2LastUpdate,
×
4257
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4258
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4259
                                Disabled:                r.Policy2Disabled,
×
4260
                                Signature:               r.Policy2Signature,
×
4261
                        }
×
4262
                }
×
4263

4264
                return policy1, policy2, nil
×
4265

4266
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4267
                if r.Policy1ID.Valid {
×
4268
                        policy1 = &sqlc.ChannelPolicy{
×
4269
                                ID:                      r.Policy1ID.Int64,
×
4270
                                Version:                 r.Policy1Version.Int16,
×
4271
                                ChannelID:               r.Channel.ID,
×
4272
                                NodeID:                  r.Policy1NodeID.Int64,
×
4273
                                Timelock:                r.Policy1Timelock.Int32,
×
4274
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4275
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4276
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4277
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4278
                                LastUpdate:              r.Policy1LastUpdate,
×
4279
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4280
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4281
                                Disabled:                r.Policy1Disabled,
×
4282
                                Signature:               r.Policy1Signature,
×
4283
                        }
×
4284
                }
×
4285
                if r.Policy2ID.Valid {
×
4286
                        policy2 = &sqlc.ChannelPolicy{
×
4287
                                ID:                      r.Policy2ID.Int64,
×
4288
                                Version:                 r.Policy2Version.Int16,
×
4289
                                ChannelID:               r.Channel.ID,
×
4290
                                NodeID:                  r.Policy2NodeID.Int64,
×
4291
                                Timelock:                r.Policy2Timelock.Int32,
×
4292
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4293
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4294
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4295
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4296
                                LastUpdate:              r.Policy2LastUpdate,
×
4297
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4298
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4299
                                Disabled:                r.Policy2Disabled,
×
4300
                                Signature:               r.Policy2Signature,
×
4301
                        }
×
4302
                }
×
4303

4304
                return policy1, policy2, nil
×
4305

4306
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4307
                if r.Policy1ID.Valid {
×
4308
                        policy1 = &sqlc.ChannelPolicy{
×
4309
                                ID:                      r.Policy1ID.Int64,
×
4310
                                Version:                 r.Policy1Version.Int16,
×
4311
                                ChannelID:               r.Channel.ID,
×
4312
                                NodeID:                  r.Policy1NodeID.Int64,
×
4313
                                Timelock:                r.Policy1Timelock.Int32,
×
4314
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4315
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4316
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4317
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4318
                                LastUpdate:              r.Policy1LastUpdate,
×
4319
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4320
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4321
                                Disabled:                r.Policy1Disabled,
×
4322
                                Signature:               r.Policy1Signature,
×
4323
                        }
×
4324
                }
×
4325
                if r.Policy2ID.Valid {
×
4326
                        policy2 = &sqlc.ChannelPolicy{
×
4327
                                ID:                      r.Policy2ID.Int64,
×
4328
                                Version:                 r.Policy2Version.Int16,
×
4329
                                ChannelID:               r.Channel.ID,
×
4330
                                NodeID:                  r.Policy2NodeID.Int64,
×
4331
                                Timelock:                r.Policy2Timelock.Int32,
×
4332
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4333
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4334
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4335
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4336
                                LastUpdate:              r.Policy2LastUpdate,
×
4337
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4338
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4339
                                Disabled:                r.Policy2Disabled,
×
4340
                                Signature:               r.Policy2Signature,
×
4341
                        }
×
4342
                }
×
4343

4344
                return policy1, policy2, nil
×
4345

4346
        case sqlc.ListChannelsByNodeIDRow:
×
4347
                if r.Policy1ID.Valid {
×
4348
                        policy1 = &sqlc.ChannelPolicy{
×
4349
                                ID:                      r.Policy1ID.Int64,
×
4350
                                Version:                 r.Policy1Version.Int16,
×
4351
                                ChannelID:               r.Channel.ID,
×
4352
                                NodeID:                  r.Policy1NodeID.Int64,
×
4353
                                Timelock:                r.Policy1Timelock.Int32,
×
4354
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4355
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4356
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4357
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4358
                                LastUpdate:              r.Policy1LastUpdate,
×
4359
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4360
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4361
                                Disabled:                r.Policy1Disabled,
×
4362
                                Signature:               r.Policy1Signature,
×
4363
                        }
×
4364
                }
×
4365
                if r.Policy2ID.Valid {
×
4366
                        policy2 = &sqlc.ChannelPolicy{
×
4367
                                ID:                      r.Policy2ID.Int64,
×
4368
                                Version:                 r.Policy2Version.Int16,
×
4369
                                ChannelID:               r.Channel.ID,
×
4370
                                NodeID:                  r.Policy2NodeID.Int64,
×
4371
                                Timelock:                r.Policy2Timelock.Int32,
×
4372
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4373
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4374
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4375
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4376
                                LastUpdate:              r.Policy2LastUpdate,
×
4377
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4378
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4379
                                Disabled:                r.Policy2Disabled,
×
4380
                                Signature:               r.Policy2Signature,
×
4381
                        }
×
4382
                }
×
4383

4384
                return policy1, policy2, nil
×
4385

4386
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4387
                if r.Policy1ID.Valid {
×
4388
                        policy1 = &sqlc.ChannelPolicy{
×
4389
                                ID:                      r.Policy1ID.Int64,
×
4390
                                Version:                 r.Policy1Version.Int16,
×
4391
                                ChannelID:               r.Channel.ID,
×
4392
                                NodeID:                  r.Policy1NodeID.Int64,
×
4393
                                Timelock:                r.Policy1Timelock.Int32,
×
4394
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4395
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4396
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4397
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4398
                                LastUpdate:              r.Policy1LastUpdate,
×
4399
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4400
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4401
                                Disabled:                r.Policy1Disabled,
×
4402
                                Signature:               r.Policy1Signature,
×
4403
                        }
×
4404
                }
×
4405
                if r.Policy2ID.Valid {
×
4406
                        policy2 = &sqlc.ChannelPolicy{
×
4407
                                ID:                      r.Policy2ID.Int64,
×
4408
                                Version:                 r.Policy2Version.Int16,
×
4409
                                ChannelID:               r.Channel.ID,
×
4410
                                NodeID:                  r.Policy2NodeID.Int64,
×
4411
                                Timelock:                r.Policy2Timelock.Int32,
×
4412
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4413
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4414
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4415
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4416
                                LastUpdate:              r.Policy2LastUpdate,
×
4417
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4418
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4419
                                Disabled:                r.Policy2Disabled,
×
4420
                                Signature:               r.Policy2Signature,
×
4421
                        }
×
4422
                }
×
4423

4424
                return policy1, policy2, nil
×
4425
        default:
×
4426
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4427
                        "extractChannelPolicies: %T", r)
×
4428
        }
4429
}
4430

4431
// channelIDToBytes converts a channel ID (SCID) to a byte array
4432
// representation.
4433
func channelIDToBytes(channelID uint64) [8]byte {
×
4434
        var chanIDB [8]byte
×
4435
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4436

×
4437
        return chanIDB
×
4438
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc