• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16312565331

16 Jul 2025 06:54AM UTC coverage: 67.221% (-0.1%) from 67.321%
16312565331

Pull #10081

github

web-flow
Merge 5d9bf9bf6 into 9059a4e7b
Pull Request #10081: graph/db: use `/*SLICE:<field_name>*/` to optimise various graph queries

0 of 378 new or added lines in 4 files covered. (0.0%)

307 existing lines in 26 files now uncovered.

135405 of 201432 relevant lines covered (67.22%)

21751.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// value is a starting point that can be tuned once benchmarks are available.
const pageSize = 2000
37

38
// ProtocolVersion identifies which version of the gossip protocol a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by the
// version's decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	version := uint8(v)

	return fmt.Sprintf("V%d", version)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannel(ctx context.Context, id int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115

116
        /*
117
                Channel Policy table queries.
118
        */
119
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
120
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
121
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
122

123
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
124
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
125
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
126

127
        /*
128
                Zombie index queries.
129
        */
130
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
131
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
132
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
133
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
134
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
135

136
        /*
137
                Prune log table queries.
138
        */
139
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
140
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
141
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
142
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
143

144
        /*
145
                Closed SCID table queries.
146
        */
147
        InsertClosedChannel(ctx context.Context, scid []byte) error
148
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
149
}
150

151
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
152
// database operations.
153
type BatchedSQLQueries interface {
154
        SQLQueries
155
        sqldb.BatchedTx[SQLQueries]
156
}
157

158
// SQLStore is an implementation of the V1Store interface that uses a SQL
159
// database as the backend.
160
type SQLStore struct {
161
        cfg *SQLStoreConfig
162
        db  BatchedSQLQueries
163

164
        // cacheMu guards all caches (rejectCache and chanCache). If
165
        // this mutex will be acquired at the same time as the DB mutex then
166
        // the cacheMu MUST be acquired first to prevent deadlock.
167
        cacheMu     sync.RWMutex
168
        rejectCache *rejectCache
169
        chanCache   *channelCache
170

171
        chanScheduler batch.Scheduler[SQLQueries]
172
        nodeScheduler batch.Scheduler[SQLQueries]
173

174
        srcNodes  map[ProtocolVersion]*srcNodeInfo
175
        srcNodeMu sync.Mutex
176
}
177

178
// A compile-time assertion to ensure that SQLStore implements the V1Store
179
// interface.
180
var _ V1Store = (*SQLStore)(nil)
181

182
// SQLStoreConfig holds the configuration for the SQLStore.
183
type SQLStoreConfig struct {
184
        // ChainHash is the genesis hash for the chain that all the gossip
185
        // messages in this store are aimed at.
186
        ChainHash chainhash.Hash
187

188
        // PaginationCfg is the configuration for paginated queries.
189
        PaginationCfg *sqldb.PagedQueryConfig
190
}
191

192
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
193
// storage backend.
194
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
195
        options ...StoreOptionModifier) (*SQLStore, error) {
×
196

×
197
        opts := DefaultOptions()
×
198
        for _, o := range options {
×
199
                o(opts)
×
200
        }
×
201

202
        if opts.NoMigration {
×
203
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
204
                        "supported for SQL stores")
×
205
        }
×
206

207
        s := &SQLStore{
×
208
                cfg:         cfg,
×
209
                db:          db,
×
210
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
211
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
212
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
213
        }
×
214

×
215
        s.chanScheduler = batch.NewTimeScheduler(
×
216
                db, &s.cacheMu, opts.BatchCommitInterval,
×
217
        )
×
218
        s.nodeScheduler = batch.NewTimeScheduler(
×
219
                db, nil, opts.BatchCommitInterval,
×
220
        )
×
221

×
222
        return s, nil
×
223
}
224

225
// AddLightningNode adds a vertex/node to the graph database. If the node is not
226
// in the database from before, this will add a new, unconnected one to the
227
// graph. If it is present from before, this will update that node's
228
// information.
229
//
230
// NOTE: part of the V1Store interface.
231
func (s *SQLStore) AddLightningNode(ctx context.Context,
232
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
233

×
234
        r := &batch.Request[SQLQueries]{
×
235
                Opts: batch.NewSchedulerOptions(opts...),
×
236
                Do: func(queries SQLQueries) error {
×
237
                        _, err := upsertNode(ctx, queries, node)
×
238
                        return err
×
239
                },
×
240
        }
241

242
        return s.nodeScheduler.Execute(ctx, r)
×
243
}
244

245
// FetchLightningNode attempts to look up a target node by its identity public
246
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
247
// returned.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) FetchLightningNode(ctx context.Context,
251
        pubKey route.Vertex) (*models.LightningNode, error) {
×
252

×
253
        var node *models.LightningNode
×
254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
255
                var err error
×
256
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
257

×
258
                return err
×
259
        }, sqldb.NoOpReset)
×
260
        if err != nil {
×
261
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
262
        }
×
263

264
        return node, nil
×
265
}
266

267
// HasLightningNode determines if the graph has a vertex identified by the
268
// target node identity public key. If the node exists in the database, a
269
// timestamp of when the data for the node was lasted updated is returned along
270
// with a true boolean. Otherwise, an empty time.Time is returned with a false
271
// boolean.
272
//
273
// NOTE: part of the V1Store interface.
274
func (s *SQLStore) HasLightningNode(ctx context.Context,
275
        pubKey [33]byte) (time.Time, bool, error) {
×
276

×
277
        var (
×
278
                exists     bool
×
279
                lastUpdate time.Time
×
280
        )
×
281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
282
                dbNode, err := db.GetNodeByPubKey(
×
283
                        ctx, sqlc.GetNodeByPubKeyParams{
×
284
                                Version: int16(ProtocolV1),
×
285
                                PubKey:  pubKey[:],
×
286
                        },
×
287
                )
×
288
                if errors.Is(err, sql.ErrNoRows) {
×
289
                        return nil
×
290
                } else if err != nil {
×
291
                        return fmt.Errorf("unable to fetch node: %w", err)
×
292
                }
×
293

294
                exists = true
×
295

×
296
                if dbNode.LastUpdate.Valid {
×
297
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return time.Time{}, false,
×
304
                        fmt.Errorf("unable to fetch node: %w", err)
×
305
        }
×
306

307
        return lastUpdate, exists, nil
×
308
}
309

310
// AddrsForNode returns all known addresses for the target node public key
311
// that the graph DB is aware of. The returned boolean indicates if the
312
// given node is unknown to the graph DB or not.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) AddrsForNode(ctx context.Context,
316
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
317

×
318
        var (
×
319
                addresses []net.Addr
×
320
                known     bool
×
321
        )
×
322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
323
                var err error
×
324
                known, addresses, err = getNodeAddresses(
×
325
                        ctx, db, nodePub.SerializeCompressed(),
×
326
                )
×
327
                if err != nil {
×
328
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
329
                                err)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
336
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
337
        }
×
338

339
        return known, addresses, nil
×
340
}
341

342
// DeleteLightningNode starts a new database transaction to remove a vertex/node
343
// from the database according to the node's public key.
344
//
345
// NOTE: part of the V1Store interface.
346
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
347
        pubKey route.Vertex) error {
×
348

×
349
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
350
                res, err := db.DeleteNodeByPubKey(
×
351
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
352
                                Version: int16(ProtocolV1),
×
353
                                PubKey:  pubKey[:],
×
354
                        },
×
355
                )
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                rows, err := res.RowsAffected()
×
361
                if err != nil {
×
362
                        return err
×
363
                }
×
364

365
                if rows == 0 {
×
366
                        return ErrGraphNodeNotFound
×
367
                } else if rows > 1 {
×
368
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
369
                }
×
370

371
                return err
×
372
        }, sqldb.NoOpReset)
373
        if err != nil {
×
374
                return fmt.Errorf("unable to delete node: %w", err)
×
375
        }
×
376

377
        return nil
×
378
}
379

380
// FetchNodeFeatures returns the features of the given node. If no features are
381
// known for the node, an empty feature vector is returned.
382
//
383
// NOTE: this is part of the graphdb.NodeTraverser interface.
384
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
385
        *lnwire.FeatureVector, error) {
×
386

×
387
        ctx := context.TODO()
×
388

×
389
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
390
}
×
391

392
// DisabledChannelIDs returns the channel ids of disabled channels.
393
// A channel is disabled when two of the associated ChanelEdgePolicies
394
// have their disabled bit on.
395
//
396
// NOTE: part of the V1Store interface.
397
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
398
        var (
×
399
                ctx     = context.TODO()
×
400
                chanIDs []uint64
×
401
        )
×
402
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
403
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
404
                if err != nil {
×
405
                        return fmt.Errorf("unable to fetch disabled "+
×
406
                                "channels: %w", err)
×
407
                }
×
408

409
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
410

×
411
                return nil
×
412
        }, sqldb.NoOpReset)
413
        if err != nil {
×
414
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
415
                        err)
×
416
        }
×
417

418
        return chanIDs, nil
×
419
}
420

421
// LookupAlias attempts to return the alias as advertised by the target node.
422
//
423
// NOTE: part of the V1Store interface.
424
func (s *SQLStore) LookupAlias(ctx context.Context,
425
        pub *btcec.PublicKey) (string, error) {
×
426

×
427
        var alias string
×
428
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
429
                dbNode, err := db.GetNodeByPubKey(
×
430
                        ctx, sqlc.GetNodeByPubKeyParams{
×
431
                                Version: int16(ProtocolV1),
×
432
                                PubKey:  pub.SerializeCompressed(),
×
433
                        },
×
434
                )
×
435
                if errors.Is(err, sql.ErrNoRows) {
×
436
                        return ErrNodeAliasNotFound
×
437
                } else if err != nil {
×
438
                        return fmt.Errorf("unable to fetch node: %w", err)
×
439
                }
×
440

441
                if !dbNode.Alias.Valid {
×
442
                        return ErrNodeAliasNotFound
×
443
                }
×
444

445
                alias = dbNode.Alias.String
×
446

×
447
                return nil
×
448
        }, sqldb.NoOpReset)
449
        if err != nil {
×
450
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
451
        }
×
452

453
        return alias, nil
×
454
}
455

456
// SourceNode returns the source node of the graph. The source node is treated
457
// as the center node within a star-graph. This method may be used to kick off
458
// a path finding algorithm in order to explore the reachability of another
459
// node based off the source node.
460
//
461
// NOTE: part of the V1Store interface.
462
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
463
        error) {
×
464

×
465
        var node *models.LightningNode
×
466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
467
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
468
                if err != nil {
×
469
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
470
                                err)
×
471
                }
×
472

473
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
474

×
475
                return err
×
476
        }, sqldb.NoOpReset)
477
        if err != nil {
×
478
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
479
        }
×
480

481
        return node, nil
×
482
}
483

484
// SetSourceNode sets the source node within the graph database. The source
485
// node is to be used as the center of a star-graph within path finding
486
// algorithms.
487
//
488
// NOTE: part of the V1Store interface.
489
func (s *SQLStore) SetSourceNode(ctx context.Context,
490
        node *models.LightningNode) error {
×
491

×
492
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
493
                id, err := upsertNode(ctx, db, node)
×
494
                if err != nil {
×
495
                        return fmt.Errorf("unable to upsert source node: %w",
×
496
                                err)
×
497
                }
×
498

499
                // Make sure that if a source node for this version is already
500
                // set, then the ID is the same as the one we are about to set.
501
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
502
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
503
                        return fmt.Errorf("unable to fetch source node: %w",
×
504
                                err)
×
505
                } else if err == nil {
×
506
                        if dbSourceNodeID != id {
×
507
                                return fmt.Errorf("v1 source node already "+
×
508
                                        "set to a different node: %d vs %d",
×
509
                                        dbSourceNodeID, id)
×
510
                        }
×
511

512
                        return nil
×
513
                }
514

515
                return db.AddSourceNode(ctx, id)
×
516
        }, sqldb.NoOpReset)
517
}
518

519
// NodeUpdatesInHorizon returns all the known lightning node which have an
520
// update timestamp within the passed range. This method can be used by two
521
// nodes to quickly determine if they have the same set of up to date node
522
// announcements.
523
//
524
// NOTE: This is part of the V1Store interface.
525
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
526
        endTime time.Time) ([]models.LightningNode, error) {
×
527

×
528
        ctx := context.TODO()
×
529

×
530
        var nodes []models.LightningNode
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
533
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
534
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
535
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
536
                        },
×
537
                )
×
538
                if err != nil {
×
539
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
540
                }
×
541

542
                for _, dbNode := range dbNodes {
×
543
                        node, err := buildNode(ctx, db, &dbNode)
×
544
                        if err != nil {
×
545
                                return fmt.Errorf("unable to build node: %w",
×
546
                                        err)
×
547
                        }
×
548

549
                        nodes = append(nodes, *node)
×
550
                }
551

552
                return nil
×
553
        }, sqldb.NoOpReset)
554
        if err != nil {
×
555
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
556
        }
×
557

558
        return nodes, nil
×
559
}
560

561
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
562
// undirected edge from the two target nodes are created. The information stored
563
// denotes the static attributes of the channel, such as the channelID, the keys
564
// involved in creation of the channel, and the set of features that the channel
565
// supports. The chanPoint and chanID are used to uniquely identify the edge
566
// globally within the database.
567
//
568
// NOTE: part of the V1Store interface.
569
func (s *SQLStore) AddChannelEdge(ctx context.Context,
570
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
571

×
572
        var alreadyExists bool
×
573
        r := &batch.Request[SQLQueries]{
×
574
                Opts: batch.NewSchedulerOptions(opts...),
×
575
                Reset: func() {
×
576
                        alreadyExists = false
×
577
                },
×
578
                Do: func(tx SQLQueries) error {
×
579
                        _, err := insertChannel(ctx, tx, edge)
×
580

×
581
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
582
                        // succeed, but propagate the error via local state.
×
583
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
584
                                alreadyExists = true
×
585
                                return nil
×
586
                        }
×
587

588
                        return err
×
589
                },
590
                OnCommit: func(err error) error {
×
591
                        switch {
×
592
                        case err != nil:
×
593
                                return err
×
594
                        case alreadyExists:
×
595
                                return ErrEdgeAlreadyExist
×
596
                        default:
×
597
                                s.rejectCache.remove(edge.ChannelID)
×
598
                                s.chanCache.remove(edge.ChannelID)
×
599
                                return nil
×
600
                        }
601
                },
602
        }
603

604
        return s.chanScheduler.Execute(ctx, r)
×
605
}
606

607
// HighestChanID returns the "highest" known channel ID in the channel graph.
608
// This represents the "newest" channel from the PoV of the chain. This method
609
// can be used by peers to quickly determine if their graphs are in sync.
610
//
611
// NOTE: This is part of the V1Store interface.
612
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
613
        var highestChanID uint64
×
614
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
616
                if errors.Is(err, sql.ErrNoRows) {
×
617
                        return nil
×
618
                } else if err != nil {
×
619
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
620
                                err)
×
621
                }
×
622

623
                highestChanID = byteOrder.Uint64(chanID)
×
624

×
625
                return nil
×
626
        }, sqldb.NoOpReset)
627
        if err != nil {
×
628
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
629
        }
×
630

631
        return highestChanID, nil
×
632
}
633

634
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
635
// within the database for the referenced channel. The `flags` attribute within
636
// the ChannelEdgePolicy determines which of the directed edges are being
637
// updated. If the flag is 1, then the first node's information is being
638
// updated, otherwise it's the second node's information. The node ordering is
639
// determined by the lexicographical ordering of the identity public keys of the
640
// nodes on either side of the channel.
641
//
642
// NOTE: part of the V1Store interface.
643
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
644
        edge *models.ChannelEdgePolicy,
645
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
646

×
647
        var (
×
648
                isUpdate1    bool
×
649
                edgeNotFound bool
×
650
                from, to     route.Vertex
×
651
        )
×
652

×
653
        r := &batch.Request[SQLQueries]{
×
654
                Opts: batch.NewSchedulerOptions(opts...),
×
655
                Reset: func() {
×
656
                        isUpdate1 = false
×
657
                        edgeNotFound = false
×
658
                },
×
659
                Do: func(tx SQLQueries) error {
×
660
                        var err error
×
661
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
662
                                ctx, tx, edge,
×
663
                        )
×
664
                        if err != nil {
×
665
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
666
                        }
×
667

668
                        // Silence ErrEdgeNotFound so that the batch can
669
                        // succeed, but propagate the error via local state.
670
                        if errors.Is(err, ErrEdgeNotFound) {
×
671
                                edgeNotFound = true
×
672
                                return nil
×
673
                        }
×
674

675
                        return err
×
676
                },
677
                OnCommit: func(err error) error {
×
678
                        switch {
×
679
                        case err != nil:
×
680
                                return err
×
681
                        case edgeNotFound:
×
682
                                return ErrEdgeNotFound
×
683
                        default:
×
684
                                s.updateEdgeCache(edge, isUpdate1)
×
685
                                return nil
×
686
                        }
687
                },
688
        }
689

690
        err := s.chanScheduler.Execute(ctx, r)
×
691

×
692
        return from, to, err
×
693
}
694

695
// updateEdgeCache updates our reject and channel caches with the new
696
// edge policy information.
697
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
698
        isUpdate1 bool) {
×
699

×
700
        // If an entry for this channel is found in reject cache, we'll modify
×
701
        // the entry with the updated timestamp for the direction that was just
×
702
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
703
        // during the next query for this edge.
×
704
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
705
                if isUpdate1 {
×
706
                        entry.upd1Time = e.LastUpdate.Unix()
×
707
                } else {
×
708
                        entry.upd2Time = e.LastUpdate.Unix()
×
709
                }
×
710
                s.rejectCache.insert(e.ChannelID, entry)
×
711
        }
712

713
        // If an entry for this channel is found in channel cache, we'll modify
714
        // the entry with the updated policy for the direction that was just
715
        // written. If the edge doesn't exist, we'll defer loading the info and
716
        // policies and lazily read from disk during the next query.
717
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
718
                if isUpdate1 {
×
719
                        channel.Policy1 = e
×
720
                } else {
×
721
                        channel.Policy2 = e
×
722
                }
×
723
                s.chanCache.insert(e.ChannelID, channel)
×
724
        }
725
}
726

727
// ForEachSourceNodeChannel iterates through all channels of the source node,
728
// executing the passed callback on each. The call-back is provided with the
729
// channel's outpoint, whether we have a policy for the channel and the channel
730
// peer's node information.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
734
        cb func(chanPoint wire.OutPoint, havePolicy bool,
735
                otherNode *models.LightningNode) error, reset func()) error {
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, reset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(ctx context.Context,
793
        cb func(tx NodeRTx) error, reset func()) error {
×
794

×
795
        var lastID int64 = 0
×
796
        handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
×
797
                node, err := buildNode(ctx, db, &dbNode)
×
798
                if err != nil {
×
799
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
800
                                dbNode.ID, err)
×
801
                }
×
802

803
                err = cb(
×
804
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
805
                )
×
806
                if err != nil {
×
807
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
808
                                dbNode.ID, err)
×
809
                }
×
810

811
                return nil
×
812
        }
813

814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
815
                for {
×
816
                        nodes, err := db.ListNodesPaginated(
×
817
                                ctx, sqlc.ListNodesPaginatedParams{
×
818
                                        Version: int16(ProtocolV1),
×
819
                                        ID:      lastID,
×
820
                                        Limit:   pageSize,
×
821
                                },
×
822
                        )
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("unable to fetch nodes: %w",
×
825
                                        err)
×
826
                        }
×
827

828
                        if len(nodes) == 0 {
×
829
                                break
×
830
                        }
831

832
                        for _, dbNode := range nodes {
×
833
                                err = handleNode(db, dbNode)
×
834
                                if err != nil {
×
835
                                        return err
×
836
                                }
×
837

838
                                lastID = dbNode.ID
×
839
                        }
840
                }
841

842
                return nil
×
843
        }, reset)
844
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error, reset func()) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, reset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
931
        cb func(route.Vertex, *lnwire.FeatureVector) error,
932
        reset func()) error {
×
933

×
934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
936
                        nodePub route.Vertex) error {
×
937

×
938
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
939
                        if err != nil {
×
940
                                return fmt.Errorf("unable to fetch node "+
×
941
                                        "features: %w", err)
×
942
                        }
×
943

944
                        return cb(nodePub, features)
×
945
                })
946
        }, reset)
947
        if err != nil {
×
948
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
949
        }
×
950

951
        return nil
×
952
}
953

954
// ForEachNodeChannel iterates through all channels of the given node,
955
// executing the passed callback with an edge info structure and the policies
956
// of each end of the channel. The first edge policy is the outgoing edge *to*
957
// the connecting node, while the second is the incoming edge *from* the
958
// connecting node. If the callback returns an error, then the iteration is
959
// halted with the error propagated back up to the caller.
960
//
961
// Unknown policies are passed into the callback as nil values.
962
//
963
// NOTE: part of the V1Store interface.
964
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
965
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
966
                *models.ChannelEdgePolicy) error, reset func()) error {
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, reset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1045
                                row.GraphChannel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1113
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1114
        reset func()) error {
×
1115

×
1116
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1117
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1118
                        nodePub route.Vertex) error {
×
1119

×
1120
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1121
                        if err != nil {
×
1122
                                return fmt.Errorf("unable to fetch "+
×
1123
                                        "node(id=%d) features: %w", nodeID, err)
×
1124
                        }
×
1125

1126
                        toNodeCallback := func() route.Vertex {
×
1127
                                return nodePub
×
1128
                        }
×
1129

1130
                        rows, err := db.ListChannelsByNodeID(
×
1131
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1132
                                        Version: int16(ProtocolV1),
×
1133
                                        NodeID1: nodeID,
×
1134
                                },
×
1135
                        )
×
1136
                        if err != nil {
×
1137
                                return fmt.Errorf("unable to fetch channels "+
×
1138
                                        "of node(id=%d): %w", nodeID, err)
×
1139
                        }
×
1140

1141
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1142
                        for _, row := range rows {
×
1143
                                node1, node2, err := buildNodeVertices(
×
1144
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1145
                                )
×
1146
                                if err != nil {
×
1147
                                        return err
×
1148
                                }
×
1149

1150
                                e, err := getAndBuildEdgeInfo(
×
1151
                                        ctx, db, s.cfg.ChainHash,
×
1152
                                        row.GraphChannel.ID, row.GraphChannel,
×
1153
                                        node1, node2,
×
1154
                                )
×
1155
                                if err != nil {
×
1156
                                        return fmt.Errorf("unable to build "+
×
1157
                                                "channel info: %w", err)
×
1158
                                }
×
1159

1160
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1161
                                        row,
×
1162
                                )
×
1163
                                if err != nil {
×
1164
                                        return fmt.Errorf("unable to "+
×
1165
                                                "extract channel "+
×
1166
                                                "policies: %w", err)
×
1167
                                }
×
1168

1169
                                p1, p2, err := getAndBuildChanPolicies(
×
1170
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1171
                                        node1, node2,
×
1172
                                )
×
1173
                                if err != nil {
×
1174
                                        return fmt.Errorf("unable to "+
×
1175
                                                "build channel policies: %w",
×
1176
                                                err)
×
1177
                                }
×
1178

1179
                                // Determine the outgoing and incoming policy
1180
                                // for this channel and node combo.
1181
                                outPolicy, inPolicy := p1, p2
×
1182
                                if p1 != nil && p1.ToNode == nodePub {
×
1183
                                        outPolicy, inPolicy = p2, p1
×
1184
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1185
                                        outPolicy, inPolicy = p2, p1
×
1186
                                }
×
1187

1188
                                var cachedInPolicy *models.CachedEdgePolicy
×
1189
                                if inPolicy != nil {
×
1190
                                        cachedInPolicy = models.NewCachedPolicy(
×
1191
                                                p2,
×
1192
                                        )
×
1193
                                        cachedInPolicy.ToNodePubKey =
×
1194
                                                toNodeCallback
×
1195
                                        cachedInPolicy.ToNodeFeatures =
×
1196
                                                features
×
1197
                                }
×
1198

1199
                                var inboundFee lnwire.Fee
×
1200
                                outPolicy.InboundFee.WhenSome(
×
1201
                                        func(fee lnwire.Fee) {
×
1202
                                                inboundFee = fee
×
1203
                                        },
×
1204
                                )
1205

1206
                                directedChannel := &DirectedChannel{
×
1207
                                        ChannelID: e.ChannelID,
×
1208
                                        IsNode1: nodePub ==
×
1209
                                                e.NodeKey1Bytes,
×
1210
                                        OtherNode:    e.NodeKey2Bytes,
×
1211
                                        Capacity:     e.Capacity,
×
1212
                                        OutPolicySet: p1 != nil,
×
1213
                                        InPolicy:     cachedInPolicy,
×
1214
                                        InboundFee:   inboundFee,
×
1215
                                }
×
1216

×
1217
                                if nodePub == e.NodeKey2Bytes {
×
1218
                                        directedChannel.OtherNode =
×
1219
                                                e.NodeKey1Bytes
×
1220
                                }
×
1221

1222
                                channels[e.ChannelID] = directedChannel
×
1223
                        }
1224

1225
                        return cb(nodePub, channels)
×
1226
                })
1227
        }, reset)
1228
}
1229

1230
// ForEachChannelCacheable iterates through all the channel edges stored
1231
// within the graph and invokes the passed callback for each edge. The
1232
// callback takes two edges as since this is a directed graph, both the
1233
// in/out edges are visited. If the callback returns an error, then the
1234
// transaction is aborted and the iteration stops early.
1235
//
1236
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1237
// pointer for that particular channel edge routing policy will be
1238
// passed into the callback.
1239
//
1240
// NOTE: this method is like ForEachChannel but fetches only the data
1241
// required for the graph cache.
1242
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1243
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1244
        reset func()) error {
×
1245

×
1246
        ctx := context.TODO()
×
1247

×
1248
        handleChannel := func(db SQLQueries,
×
1249
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1250

×
1251
                node1, node2, err := buildNodeVertices(
×
1252
                        row.Node1Pubkey, row.Node2Pubkey,
×
1253
                )
×
1254
                if err != nil {
×
1255
                        return err
×
1256
                }
×
1257

1258
                edge := buildCacheableChannelInfo(
×
1259
                        row.GraphChannel, node1, node2,
×
1260
                )
×
1261

×
1262
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1263
                if err != nil {
×
1264
                        return err
×
1265
                }
×
1266

1267
                var pol1, pol2 *models.CachedEdgePolicy
×
1268
                if dbPol1 != nil {
×
1269
                        policy1, err := buildChanPolicy(
×
1270
                                *dbPol1, edge.ChannelID, nil, node2,
×
1271
                        )
×
1272
                        if err != nil {
×
1273
                                return err
×
1274
                        }
×
1275

1276
                        pol1 = models.NewCachedPolicy(policy1)
×
1277
                }
1278
                if dbPol2 != nil {
×
1279
                        policy2, err := buildChanPolicy(
×
1280
                                *dbPol2, edge.ChannelID, nil, node1,
×
1281
                        )
×
1282
                        if err != nil {
×
1283
                                return err
×
1284
                        }
×
1285

1286
                        pol2 = models.NewCachedPolicy(policy2)
×
1287
                }
1288

1289
                if err := cb(edge, pol1, pol2); err != nil {
×
1290
                        return err
×
1291
                }
×
1292

1293
                return nil
×
1294
        }
1295

1296
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1297
                lastID := int64(-1)
×
1298
                for {
×
1299
                        //nolint:ll
×
1300
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1301
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1302
                                        Version: int16(ProtocolV1),
×
1303
                                        ID:      lastID,
×
1304
                                        Limit:   pageSize,
×
1305
                                },
×
1306
                        )
×
1307
                        if err != nil {
×
1308
                                return err
×
1309
                        }
×
1310

1311
                        if len(rows) == 0 {
×
1312
                                break
×
1313
                        }
1314

1315
                        for _, row := range rows {
×
1316
                                err := handleChannel(db, row)
×
1317
                                if err != nil {
×
1318
                                        return err
×
1319
                                }
×
1320

1321
                                lastID = row.GraphChannel.ID
×
1322
                        }
1323
                }
1324

1325
                return nil
×
1326
        }, reset)
1327
}
1328

1329
// ForEachChannel iterates through all the channel edges stored within the
1330
// graph and invokes the passed callback for each edge. The callback takes two
1331
// edges as since this is a directed graph, both the in/out edges are visited.
1332
// If the callback returns an error, then the transaction is aborted and the
1333
// iteration stops early.
1334
//
1335
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1336
// for that particular channel edge routing policy will be passed into the
1337
// callback.
1338
//
1339
// NOTE: part of the V1Store interface.
1340
func (s *SQLStore) ForEachChannel(ctx context.Context,
1341
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1342
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1357
                        row.GraphChannel, node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.GraphChannel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, reset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart,
×
1456
                                EndScid:   chanIDEnd,
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB,
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB,
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB,
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1746
                                row.GraphChannel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB,
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
                chanIDB          = channelIDToBytes(chanID)
×
1837
        )
×
1838
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1839
                row, err := db.GetChannelBySCIDWithPolicies(
×
1840
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1841
                                Scid:    chanIDB,
×
1842
                                Version: int16(ProtocolV1),
×
1843
                        },
×
1844
                )
×
1845
                if errors.Is(err, sql.ErrNoRows) {
×
1846
                        // First check if this edge is perhaps in the zombie
×
1847
                        // index.
×
1848
                        zombie, err := db.GetZombieChannel(
×
1849
                                ctx, sqlc.GetZombieChannelParams{
×
1850
                                        Scid:    chanIDB,
×
1851
                                        Version: int16(ProtocolV1),
×
1852
                                },
×
1853
                        )
×
1854
                        if errors.Is(err, sql.ErrNoRows) {
×
1855
                                return ErrEdgeNotFound
×
1856
                        } else if err != nil {
×
1857
                                return fmt.Errorf("unable to check if "+
×
1858
                                        "channel is zombie: %w", err)
×
1859
                        }
×
1860

1861
                        // At this point, we know the channel is a zombie, so
1862
                        // we'll return an error indicating this, and we will
1863
                        // populate the edge info with the public keys of each
1864
                        // party as this is the only information we have about
1865
                        // it.
1866
                        edge = &models.ChannelEdgeInfo{}
×
1867
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1868
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1869

×
1870
                        return ErrZombieEdge
×
1871
                } else if err != nil {
×
1872
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1873
                }
×
1874

1875
                node1, node2, err := buildNodeVertices(
×
1876
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1877
                )
×
1878
                if err != nil {
×
1879
                        return err
×
1880
                }
×
1881

1882
                edge, err = getAndBuildEdgeInfo(
×
1883
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1884
                        row.GraphChannel, node1, node2,
×
1885
                )
×
1886
                if err != nil {
×
1887
                        return fmt.Errorf("unable to build channel info: %w",
×
1888
                                err)
×
1889
                }
×
1890

1891
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1892
                if err != nil {
×
1893
                        return fmt.Errorf("unable to extract channel "+
×
1894
                                "policies: %w", err)
×
1895
                }
×
1896

1897
                policy1, policy2, err = getAndBuildChanPolicies(
×
1898
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1899
                )
×
1900
                if err != nil {
×
1901
                        return fmt.Errorf("unable to build channel "+
×
1902
                                "policies: %w", err)
×
1903
                }
×
1904

1905
                return nil
×
1906
        }, sqldb.NoOpReset)
1907
        if err != nil {
×
1908
                // If we are returning the ErrZombieEdge, then we also need to
×
1909
                // return the edge info as the method comment indicates that
×
1910
                // this will be populated when the edge is a zombie.
×
1911
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1912
                        err)
×
1913
        }
×
1914

1915
        return edge, policy1, policy2, nil
×
1916
}
1917

1918
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1919
// the channel identified by the funding outpoint. If the channel can't be
1920
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1921
// information for the channel itself is returned as well as two structs that
1922
// contain the routing policies for the channel in either direction.
1923
//
1924
// NOTE: part of the V1Store interface.
1925
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1926
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1927
        *models.ChannelEdgePolicy, error) {
×
1928

×
1929
        var (
×
1930
                ctx              = context.TODO()
×
1931
                edge             *models.ChannelEdgeInfo
×
1932
                policy1, policy2 *models.ChannelEdgePolicy
×
1933
        )
×
1934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1935
                row, err := db.GetChannelByOutpointWithPolicies(
×
1936
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1937
                                Outpoint: op.String(),
×
1938
                                Version:  int16(ProtocolV1),
×
1939
                        },
×
1940
                )
×
1941
                if errors.Is(err, sql.ErrNoRows) {
×
1942
                        return ErrEdgeNotFound
×
1943
                } else if err != nil {
×
1944
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1945
                }
×
1946

1947
                node1, node2, err := buildNodeVertices(
×
1948
                        row.Node1Pubkey, row.Node2Pubkey,
×
1949
                )
×
1950
                if err != nil {
×
1951
                        return err
×
1952
                }
×
1953

1954
                edge, err = getAndBuildEdgeInfo(
×
1955
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1956
                        row.GraphChannel, node1, node2,
×
1957
                )
×
1958
                if err != nil {
×
1959
                        return fmt.Errorf("unable to build channel info: %w",
×
1960
                                err)
×
1961
                }
×
1962

1963
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1964
                if err != nil {
×
1965
                        return fmt.Errorf("unable to extract channel "+
×
1966
                                "policies: %w", err)
×
1967
                }
×
1968

1969
                policy1, policy2, err = getAndBuildChanPolicies(
×
1970
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1971
                )
×
1972
                if err != nil {
×
1973
                        return fmt.Errorf("unable to build channel "+
×
1974
                                "policies: %w", err)
×
1975
                }
×
1976

1977
                return nil
×
1978
        }, sqldb.NoOpReset)
1979
        if err != nil {
×
1980
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1981
                        err)
×
1982
        }
×
1983

1984
        return edge, policy1, policy2, nil
×
1985
}
1986

1987
// HasChannelEdge returns true if the database knows of a channel edge with the
1988
// passed channel ID, and false otherwise. If an edge with that ID is found
1989
// within the graph, then two time stamps representing the last time the edge
1990
// was updated for both directed edges are returned along with the boolean. If
1991
// it is not found, then the zombie index is checked and its result is returned
1992
// as the second boolean.
1993
//
1994
// NOTE: part of the V1Store interface.
1995
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1996
        bool, error) {
×
1997

×
1998
        ctx := context.TODO()
×
1999

×
2000
        var (
×
2001
                exists          bool
×
2002
                isZombie        bool
×
2003
                node1LastUpdate time.Time
×
2004
                node2LastUpdate time.Time
×
2005
        )
×
2006

×
2007
        // We'll query the cache with the shared lock held to allow multiple
×
2008
        // readers to access values in the cache concurrently if they exist.
×
2009
        s.cacheMu.RLock()
×
2010
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2011
                s.cacheMu.RUnlock()
×
2012
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2013
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2014
                exists, isZombie = entry.flags.unpack()
×
2015

×
2016
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2017
        }
×
2018
        s.cacheMu.RUnlock()
×
2019

×
2020
        s.cacheMu.Lock()
×
2021
        defer s.cacheMu.Unlock()
×
2022

×
2023
        // The item was not found with the shared lock, so we'll acquire the
×
2024
        // exclusive lock and check the cache again in case another method added
×
2025
        // the entry to the cache while no lock was held.
×
2026
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2027
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2028
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2029
                exists, isZombie = entry.flags.unpack()
×
2030

×
2031
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2032
        }
×
2033

2034
        chanIDB := channelIDToBytes(chanID)
×
2035
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2036
                channel, err := db.GetChannelBySCID(
×
2037
                        ctx, sqlc.GetChannelBySCIDParams{
×
2038
                                Scid:    chanIDB,
×
2039
                                Version: int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        // Check if it is a zombie channel.
×
2044
                        isZombie, err = db.IsZombieChannel(
×
2045
                                ctx, sqlc.IsZombieChannelParams{
×
2046
                                        Scid:    chanIDB,
×
2047
                                        Version: int16(ProtocolV1),
×
2048
                                },
×
2049
                        )
×
2050
                        if err != nil {
×
2051
                                return fmt.Errorf("could not check if channel "+
×
2052
                                        "is zombie: %w", err)
×
2053
                        }
×
2054

2055
                        return nil
×
2056
                } else if err != nil {
×
2057
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2058
                }
×
2059

2060
                exists = true
×
2061

×
2062
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2063
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2064
                                Version:   int16(ProtocolV1),
×
2065
                                ChannelID: channel.ID,
×
2066
                                NodeID:    channel.NodeID1,
×
2067
                        },
×
2068
                )
×
2069
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2070
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2071
                                err)
×
2072
                } else if err == nil {
×
2073
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2074
                }
×
2075

2076
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2077
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2078
                                Version:   int16(ProtocolV1),
×
2079
                                ChannelID: channel.ID,
×
2080
                                NodeID:    channel.NodeID2,
×
2081
                        },
×
2082
                )
×
2083
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2084
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2085
                                err)
×
2086
                } else if err == nil {
×
2087
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2088
                }
×
2089

2090
                return nil
×
2091
        }, sqldb.NoOpReset)
2092
        if err != nil {
×
2093
                return time.Time{}, time.Time{}, false, false,
×
2094
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2095
        }
×
2096

2097
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2098
                upd1Time: node1LastUpdate.Unix(),
×
2099
                upd2Time: node2LastUpdate.Unix(),
×
2100
                flags:    packRejectFlags(exists, isZombie),
×
2101
        })
×
2102

×
2103
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2104
}
2105

2106
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2107
// passed channel point (outpoint). If the passed channel doesn't exist within
2108
// the database, then ErrEdgeNotFound is returned.
2109
//
2110
// NOTE: part of the V1Store interface.
2111
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2112
        var (
×
2113
                ctx       = context.TODO()
×
2114
                channelID uint64
×
2115
        )
×
2116
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2117
                chanID, err := db.GetSCIDByOutpoint(
×
2118
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2119
                                Outpoint: chanPoint.String(),
×
2120
                                Version:  int16(ProtocolV1),
×
2121
                        },
×
2122
                )
×
2123
                if errors.Is(err, sql.ErrNoRows) {
×
2124
                        return ErrEdgeNotFound
×
2125
                } else if err != nil {
×
2126
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2127
                                err)
×
2128
                }
×
2129

2130
                channelID = byteOrder.Uint64(chanID)
×
2131

×
2132
                return nil
×
2133
        }, sqldb.NoOpReset)
2134
        if err != nil {
×
2135
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2136
        }
×
2137

2138
        return channelID, nil
×
2139
}
2140

2141
// IsPublicNode is a helper method that determines whether the node with the
2142
// given public key is seen as a public node in the graph from the graph's
2143
// source node's point of view.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2147
        ctx := context.TODO()
×
2148

×
2149
        var isPublic bool
×
2150
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2151
                var err error
×
2152
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2153

×
2154
                return err
×
2155
        }, sqldb.NoOpReset)
×
2156
        if err != nil {
×
2157
                return false, fmt.Errorf("unable to check if node is "+
×
2158
                        "public: %w", err)
×
2159
        }
×
2160

2161
        return isPublic, nil
×
2162
}
2163

2164
// FetchChanInfos returns the set of channel edges that correspond to the passed
2165
// channel ID's. If an edge is the query is unknown to the database, it will
2166
// skipped and the result will contain only those edges that exist at the time
2167
// of the query. This can be used to respond to peer queries that are seeking to
2168
// fill in gaps in their view of the channel graph.
2169
//
2170
// NOTE: part of the V1Store interface.
2171
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2172
        var (
×
2173
                ctx   = context.TODO()
×
NEW
2174
                edges = make(map[uint64]ChannelEdge)
×
2175
        )
×
2176
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2177
                chanCallBack := func(ctx context.Context,
×
NEW
2178
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2179

×
2180
                        node1, node2, err := buildNodes(
×
2181
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2182
                        )
×
2183
                        if err != nil {
×
2184
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2185
                                        err)
×
2186
                        }
×
2187

2188
                        edge, err := getAndBuildEdgeInfo(
×
2189
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2190
                                row.GraphChannel, node1.PubKeyBytes,
×
2191
                                node2.PubKeyBytes,
×
2192
                        )
×
2193
                        if err != nil {
×
2194
                                return fmt.Errorf("unable to build "+
×
2195
                                        "channel info: %w", err)
×
2196
                        }
×
2197

2198
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to extract channel "+
×
2201
                                        "policies: %w", err)
×
2202
                        }
×
2203

2204
                        p1, p2, err := getAndBuildChanPolicies(
×
2205
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2206
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2207
                        )
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to build channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

NEW
2213
                        edges[edge.ChannelID] = ChannelEdge{
×
2214
                                Info:    edge,
×
2215
                                Policy1: p1,
×
2216
                                Policy2: p2,
×
2217
                                Node1:   node1,
×
2218
                                Node2:   node2,
×
NEW
2219
                        }
×
NEW
2220

×
NEW
2221
                        return nil
×
2222
                }
2223

NEW
2224
                return s.forEachChanWithPoliciesInSCIDList(
×
NEW
2225
                        ctx, db, chanCallBack, chanIDs,
×
NEW
2226
                )
×
2227
        }, func() {
×
NEW
2228
                clear(edges)
×
2229
        })
×
2230
        if err != nil {
×
2231
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2232
        }
×
2233

NEW
2234
        res := make([]ChannelEdge, 0, len(edges))
×
NEW
2235
        for _, chanID := range chanIDs {
×
NEW
2236
                edge, ok := edges[chanID]
×
NEW
2237
                if !ok {
×
NEW
2238
                        continue
×
2239
                }
2240

NEW
2241
                res = append(res, edge)
×
2242
        }
2243

NEW
2244
        return res, nil
×
2245
}
2246

2247
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2248
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2249
// channels in a paginated manner.
2250
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2251
        db SQLQueries, cb func(ctx context.Context,
2252
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
NEW
2253
        chanIDs []uint64) error {
×
NEW
2254

×
NEW
2255
        queryWrapper := func(ctx context.Context,
×
NEW
2256
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
NEW
2257
                error) {
×
NEW
2258

×
NEW
2259
                return db.GetChannelsBySCIDWithPolicies(
×
NEW
2260
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2261
                                Version: int16(ProtocolV1),
×
NEW
2262
                                Scids:   scids,
×
NEW
2263
                        },
×
NEW
2264
                )
×
NEW
2265
        }
×
2266

NEW
2267
        return sqldb.ExecutePagedQuery(
×
NEW
2268
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
NEW
2269
                queryWrapper, cb,
×
NEW
2270
        )
×
2271
}
2272

2273
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2274
// ID's that we don't know and are not known zombies of the passed set. In other
2275
// words, we perform a set difference of our set of chan ID's and the ones
2276
// passed in. This method can be used by callers to determine the set of
2277
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2278
// known zombies is also returned.
2279
//
2280
// NOTE: part of the V1Store interface.
2281
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2282
        []ChannelUpdateInfo, error) {
×
2283

×
2284
        var (
×
2285
                ctx          = context.TODO()
×
2286
                newChanIDs   []uint64
×
2287
                knownZombies []ChannelUpdateInfo
×
NEW
2288
                infoLookup   = make(
×
NEW
2289
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
NEW
2290
                )
×
2291
        )
×
NEW
2292

×
NEW
2293
        // We first build a lookup map of the channel ID's to the
×
NEW
2294
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
NEW
2295
        // already know about.
×
NEW
2296
        for _, chanInfo := range chansInfo {
×
NEW
2297
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
NEW
2298
        }
×
2299

2300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2301
                // The call-back function deletes known channels from
×
NEW
2302
                // infoLookup, so that we can later check which channels are
×
NEW
2303
                // zombies by only looking at the remaining channels in the set.
×
NEW
2304
                cb := func(ctx context.Context,
×
NEW
2305
                        channel sqlc.GraphChannel) error {
×
NEW
2306

×
NEW
2307
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
NEW
2308

×
NEW
2309
                        return nil
×
NEW
2310
                }
×
2311

NEW
2312
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
NEW
2313
                if err != nil {
×
NEW
2314
                        return fmt.Errorf("unable to iterate through "+
×
NEW
2315
                                "channels: %w", err)
×
NEW
2316
                }
×
2317

2318
                // We want to ensure that we deal with the channels in the
2319
                // same order that they were passed in, so we iterate over the
2320
                // original chansInfo slice and then check if that channel is
2321
                // still in the infoLookup map.
2322
                for _, chanInfo := range chansInfo {
×
2323
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2324
                        if _, ok := infoLookup[channelID]; !ok {
×
2325
                                continue
×
2326
                        }
2327

2328
                        isZombie, err := db.IsZombieChannel(
×
2329
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2330
                                        Scid:    channelIDToBytes(channelID),
×
2331
                                        Version: int16(ProtocolV1),
×
2332
                                },
×
2333
                        )
×
2334
                        if err != nil {
×
2335
                                return fmt.Errorf("unable to fetch zombie "+
×
2336
                                        "channel: %w", err)
×
2337
                        }
×
2338

2339
                        if isZombie {
×
2340
                                knownZombies = append(knownZombies, chanInfo)
×
2341

×
2342
                                continue
×
2343
                        }
2344

2345
                        newChanIDs = append(newChanIDs, channelID)
×
2346
                }
2347

2348
                return nil
×
2349
        }, func() {
×
2350
                newChanIDs = nil
×
2351
                knownZombies = nil
×
2352
        })
×
2353
        if err != nil {
×
2354
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2355
        }
×
2356

2357
        return newChanIDs, knownZombies, nil
×
2358
}
2359

2360
// forEachChanInSCIDList is a helper method that executes a paged query
2361
// against the database to fetch all channels that match the passed
2362
// ChannelUpdateInfo slice. The callback function is called for each channel
2363
// that is found.
2364
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2365
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
NEW
2366
        chansInfo []ChannelUpdateInfo) error {
×
NEW
2367

×
NEW
2368
        queryWrapper := func(ctx context.Context,
×
NEW
2369
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
NEW
2370

×
NEW
2371
                return db.GetChannelsBySCIDs(
×
NEW
2372
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2373
                                Version: int16(ProtocolV1),
×
NEW
2374
                                Scids:   scids,
×
NEW
2375
                        },
×
NEW
2376
                )
×
NEW
2377
        }
×
2378

NEW
2379
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
NEW
2380
                channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2381

×
NEW
2382
                return channelIDToBytes(channelID)
×
NEW
2383
        }
×
2384

NEW
2385
        return sqldb.ExecutePagedQuery(
×
NEW
2386
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
NEW
2387
                queryWrapper, cb,
×
NEW
2388
        )
×
2389
}
2390

2391
// PruneGraphNodes is a garbage collection method which attempts to prune out
2392
// any nodes from the channel graph that are currently unconnected. This ensure
2393
// that we only maintain a graph of reachable nodes. In the event that a pruned
2394
// node gains more channels, it will be re-added back to the graph.
2395
//
2396
// NOTE: this prunes nodes across protocol versions. It will never prune the
2397
// source nodes.
2398
//
2399
// NOTE: part of the V1Store interface.
2400
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2401
        var ctx = context.TODO()
×
2402

×
2403
        var prunedNodes []route.Vertex
×
2404
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2405
                var err error
×
2406
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2407

×
2408
                return err
×
2409
        }, func() {
×
2410
                prunedNodes = nil
×
2411
        })
×
2412
        if err != nil {
×
2413
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2414
        }
×
2415

2416
        return prunedNodes, nil
×
2417
}
2418

2419
// PruneGraph prunes newly closed channels from the channel graph in response
2420
// to a new block being solved on the network. Any transactions which spend the
2421
// funding output of any known channels within he graph will be deleted.
2422
// Additionally, the "prune tip", or the last block which has been used to
2423
// prune the graph is stored so callers can ensure the graph is fully in sync
2424
// with the current UTXO state. A slice of channels that have been closed by
2425
// the target block along with any pruned nodes are returned if the function
2426
// succeeds without error.
2427
//
2428
// NOTE: part of the V1Store interface.
2429
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2430
        blockHash *chainhash.Hash, blockHeight uint32) (
2431
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2432

×
2433
        ctx := context.TODO()
×
2434

×
2435
        s.cacheMu.Lock()
×
2436
        defer s.cacheMu.Unlock()
×
2437

×
2438
        var (
×
2439
                closedChans []*models.ChannelEdgeInfo
×
2440
                prunedNodes []route.Vertex
×
2441
        )
×
2442
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2443
                // Define the callback function for processing each channel
×
NEW
2444
                channelCallback := func(ctx context.Context,
×
NEW
2445
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2446

×
2447
                        node1, node2, err := buildNodeVertices(
×
2448
                                row.Node1Pubkey, row.Node2Pubkey,
×
2449
                        )
×
2450
                        if err != nil {
×
2451
                                return err
×
2452
                        }
×
2453

2454
                        info, err := getAndBuildEdgeInfo(
×
2455
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2456
                                row.GraphChannel, node1, node2,
×
2457
                        )
×
2458
                        if err != nil {
×
2459
                                return err
×
2460
                        }
×
2461

2462
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2463
                        if err != nil {
×
2464
                                return fmt.Errorf("unable to delete "+
×
2465
                                        "channel: %w", err)
×
2466
                        }
×
2467

2468
                        closedChans = append(closedChans, info)
×
NEW
2469
                        return nil
×
2470
                }
2471

UNCOV
2472
                err := s.forEachChanInOutpoints(
×
NEW
2473
                        ctx, db, spentOutputs, channelCallback,
×
NEW
2474
                )
×
NEW
2475
                if err != nil {
×
NEW
2476
                        return fmt.Errorf("unable to fetch channels by "+
×
NEW
2477
                                "outpoints: %w", err)
×
NEW
2478
                }
×
2479

NEW
2480
                err = db.UpsertPruneLogEntry(
×
NEW
2481
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2482
                                BlockHash:   blockHash[:],
×
2483
                                BlockHeight: int64(blockHeight),
×
2484
                        },
×
2485
                )
×
2486
                if err != nil {
×
2487
                        return fmt.Errorf("unable to insert prune log "+
×
2488
                                "entry: %w", err)
×
2489
                }
×
2490

2491
                // Now that we've pruned some channels, we'll also prune any
2492
                // nodes that no longer have any channels.
UNCOV
2493
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2494
                if err != nil {
×
2495
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2496
                                err)
×
2497
                }
×
2498

UNCOV
2499
                return nil
×
2500
        }, func() {
×
2501
                prunedNodes = nil
×
2502
                closedChans = nil
×
2503
        })
×
2504
        if err != nil {
×
2505
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2506
        }
×
2507

UNCOV
2508
        for _, channel := range closedChans {
×
2509
                s.rejectCache.remove(channel.ChannelID)
×
2510
                s.chanCache.remove(channel.ChannelID)
×
2511
        }
×
2512

UNCOV
2513
        return closedChans, prunedNodes, nil
×
2514
}
2515

2516
// forEachChanInOutpoints is a helper function that executes a paginated
2517
// query to fetch channels by their outpoints and applies the given call-back
2518
// to each.
2519
//
2520
// NOTE: this fetches channels for all protocol versions.
2521
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2522
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
NEW
2523
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
NEW
2524

×
NEW
2525
        // Create a wrapper that uses the transaction's db instance to execute
×
NEW
2526
        // the query.
×
NEW
2527
        queryWrapper := func(ctx context.Context,
×
NEW
2528
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
NEW
2529
                error) {
×
NEW
2530

×
NEW
2531
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
NEW
2532
        }
×
2533

2534
        // Define the conversion function from Outpoint to string
NEW
2535
        outpointToString := func(outpoint *wire.OutPoint) string {
×
NEW
2536
                return outpoint.String()
×
NEW
2537
        }
×
2538

NEW
2539
        return sqldb.ExecutePagedQuery(
×
NEW
2540
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
NEW
2541
                queryWrapper, cb,
×
NEW
2542
        )
×
2543
}
2544

2545
// ChannelView returns the verifiable edge information for each active channel
2546
// within the known channel graph. The set of UTXOs (along with their scripts)
2547
// returned are the ones that need to be watched on chain to detect channel
2548
// closes on the resident blockchain.
2549
//
2550
// NOTE: part of the V1Store interface.
UNCOV
2551
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2552
        var (
×
2553
                ctx        = context.TODO()
×
2554
                edgePoints []EdgePoint
×
2555
        )
×
2556

×
2557
        handleChannel := func(db SQLQueries,
×
2558
                channel sqlc.ListChannelsPaginatedRow) error {
×
2559

×
2560
                pkScript, err := genMultiSigP2WSH(
×
2561
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2562
                )
×
2563
                if err != nil {
×
2564
                        return err
×
2565
                }
×
2566

UNCOV
2567
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2568
                if err != nil {
×
2569
                        return err
×
2570
                }
×
2571

UNCOV
2572
                edgePoints = append(edgePoints, EdgePoint{
×
2573
                        FundingPkScript: pkScript,
×
2574
                        OutPoint:        *op,
×
2575
                })
×
2576

×
2577
                return nil
×
2578
        }
2579

UNCOV
2580
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2581
                lastID := int64(-1)
×
2582
                for {
×
2583
                        rows, err := db.ListChannelsPaginated(
×
2584
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2585
                                        Version: int16(ProtocolV1),
×
2586
                                        ID:      lastID,
×
2587
                                        Limit:   pageSize,
×
2588
                                },
×
2589
                        )
×
2590
                        if err != nil {
×
2591
                                return err
×
2592
                        }
×
2593

UNCOV
2594
                        if len(rows) == 0 {
×
2595
                                break
×
2596
                        }
2597

UNCOV
2598
                        for _, row := range rows {
×
2599
                                err := handleChannel(db, row)
×
2600
                                if err != nil {
×
2601
                                        return err
×
2602
                                }
×
2603

UNCOV
2604
                                lastID = row.ID
×
2605
                        }
2606
                }
2607

UNCOV
2608
                return nil
×
2609
        }, func() {
×
2610
                edgePoints = nil
×
2611
        })
×
2612
        if err != nil {
×
2613
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2614
        }
×
2615

UNCOV
2616
        return edgePoints, nil
×
2617
}
2618

2619
// PruneTip returns the block height and hash of the latest block that has been
2620
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2621
// to tell if the graph is currently in sync with the current best known UTXO
2622
// state.
2623
//
2624
// NOTE: part of the V1Store interface.
UNCOV
2625
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2626
        var (
×
2627
                ctx       = context.TODO()
×
2628
                tipHash   chainhash.Hash
×
2629
                tipHeight uint32
×
2630
        )
×
2631
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2632
                pruneTip, err := db.GetPruneTip(ctx)
×
2633
                if errors.Is(err, sql.ErrNoRows) {
×
2634
                        return ErrGraphNeverPruned
×
2635
                } else if err != nil {
×
2636
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2637
                }
×
2638

UNCOV
2639
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2640
                tipHeight = uint32(pruneTip.BlockHeight)
×
2641

×
2642
                return nil
×
2643
        }, sqldb.NoOpReset)
UNCOV
2644
        if err != nil {
×
2645
                return nil, 0, err
×
2646
        }
×
2647

UNCOV
2648
        return &tipHash, tipHeight, nil
×
2649
}
2650

2651
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2652
//
2653
// NOTE: this prunes nodes across protocol versions. It will never prune the
2654
// source nodes.
2655
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
UNCOV
2656
        db SQLQueries) ([]route.Vertex, error) {
×
2657

×
2658
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2659
        if err != nil {
×
2660
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2661
                        "nodes: %w", err)
×
2662
        }
×
2663

UNCOV
2664
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2665
        for i, nodeKey := range nodeKeys {
×
2666
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2667
                if err != nil {
×
2668
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2669
                                "from bytes: %w", err)
×
2670
                }
×
2671

UNCOV
2672
                prunedNodes[i] = pub
×
2673
        }
2674

UNCOV
2675
        return prunedNodes, nil
×
2676
}
2677

2678
// DisconnectBlockAtHeight is used to indicate that the block specified
2679
// by the passed height has been disconnected from the main chain. This
2680
// will "rewind" the graph back to the height below, deleting channels
2681
// that are no longer confirmed from the graph. The prune log will be
2682
// set to the last prune height valid for the remaining chain.
2683
// Channels that were removed from the graph resulting from the
2684
// disconnected block are returned.
2685
//
2686
// NOTE: part of the V1Store interface.
2687
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
UNCOV
2688
        []*models.ChannelEdgeInfo, error) {
×
2689

×
2690
        ctx := context.TODO()
×
2691

×
2692
        var (
×
2693
                // Every channel having a ShortChannelID starting at 'height'
×
2694
                // will no longer be confirmed.
×
2695
                startShortChanID = lnwire.ShortChannelID{
×
2696
                        BlockHeight: height,
×
2697
                }
×
2698

×
2699
                // Delete everything after this height from the db up until the
×
2700
                // SCID alias range.
×
2701
                endShortChanID = aliasmgr.StartingAlias
×
2702

×
2703
                removedChans []*models.ChannelEdgeInfo
×
2704

×
2705
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2706
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2707
        )
×
2708

×
2709
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2710
                rows, err := db.GetChannelsBySCIDRange(
×
2711
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2712
                                StartScid: chanIDStart,
×
2713
                                EndScid:   chanIDEnd,
×
2714
                        },
×
2715
                )
×
2716
                if err != nil {
×
2717
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2718
                }
×
2719

UNCOV
2720
                for _, row := range rows {
×
2721
                        node1, node2, err := buildNodeVertices(
×
2722
                                row.Node1PubKey, row.Node2PubKey,
×
2723
                        )
×
2724
                        if err != nil {
×
2725
                                return err
×
2726
                        }
×
2727

UNCOV
2728
                        channel, err := getAndBuildEdgeInfo(
×
2729
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2730
                                row.GraphChannel, node1, node2,
×
2731
                        )
×
2732
                        if err != nil {
×
2733
                                return err
×
2734
                        }
×
2735

UNCOV
2736
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2737
                        if err != nil {
×
2738
                                return fmt.Errorf("unable to delete "+
×
2739
                                        "channel: %w", err)
×
2740
                        }
×
2741

UNCOV
2742
                        removedChans = append(removedChans, channel)
×
2743
                }
2744

UNCOV
2745
                return db.DeletePruneLogEntriesInRange(
×
2746
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2747
                                StartHeight: int64(height),
×
2748
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2749
                        },
×
2750
                )
×
2751
        }, func() {
×
2752
                removedChans = nil
×
2753
        })
×
2754
        if err != nil {
×
2755
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2756
                        "height: %w", err)
×
2757
        }
×
2758

UNCOV
2759
        for _, channel := range removedChans {
×
2760
                s.rejectCache.remove(channel.ChannelID)
×
2761
                s.chanCache.remove(channel.ChannelID)
×
2762
        }
×
2763

UNCOV
2764
        return removedChans, nil
×
2765
}
2766

2767
// AddEdgeProof sets the proof of an existing edge in the graph database.
2768
//
2769
// NOTE: part of the V1Store interface.
2770
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
UNCOV
2771
        proof *models.ChannelAuthProof) error {
×
2772

×
2773
        var (
×
2774
                ctx       = context.TODO()
×
2775
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2776
        )
×
2777

×
2778
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2779
                res, err := db.AddV1ChannelProof(
×
2780
                        ctx, sqlc.AddV1ChannelProofParams{
×
2781
                                Scid:              scidBytes,
×
2782
                                Node1Signature:    proof.NodeSig1Bytes,
×
2783
                                Node2Signature:    proof.NodeSig2Bytes,
×
2784
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2785
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2786
                        },
×
2787
                )
×
2788
                if err != nil {
×
2789
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2790
                }
×
2791

UNCOV
2792
                n, err := res.RowsAffected()
×
2793
                if err != nil {
×
2794
                        return err
×
2795
                }
×
2796

UNCOV
2797
                if n == 0 {
×
2798
                        return fmt.Errorf("no rows affected when adding edge "+
×
2799
                                "proof for SCID %v", scid)
×
2800
                } else if n > 1 {
×
2801
                        return fmt.Errorf("multiple rows affected when adding "+
×
2802
                                "edge proof for SCID %v: %d rows affected",
×
2803
                                scid, n)
×
2804
                }
×
2805

UNCOV
2806
                return nil
×
2807
        }, sqldb.NoOpReset)
UNCOV
2808
        if err != nil {
×
2809
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2810
        }
×
2811

UNCOV
2812
        return nil
×
2813
}
2814

2815
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2816
// that we can ignore channel announcements that we know to be closed without
2817
// having to validate them and fetch a block.
2818
//
2819
// NOTE: part of the V1Store interface.
UNCOV
2820
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2821
        var (
×
2822
                ctx     = context.TODO()
×
2823
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2824
        )
×
2825

×
2826
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2827
                return db.InsertClosedChannel(ctx, chanIDB)
×
2828
        }, sqldb.NoOpReset)
×
2829
}
2830

2831
// IsClosedScid checks whether a channel identified by the passed in scid is
2832
// closed. This helps avoid having to perform expensive validation checks.
2833
//
2834
// NOTE: part of the V1Store interface.
UNCOV
2835
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2836
        var (
×
2837
                ctx      = context.TODO()
×
2838
                isClosed bool
×
2839
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2840
        )
×
2841
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2842
                var err error
×
2843
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2844
                if err != nil {
×
2845
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2846
                                err)
×
2847
                }
×
2848

UNCOV
2849
                return nil
×
2850
        }, sqldb.NoOpReset)
UNCOV
2851
        if err != nil {
×
2852
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2853
                        err)
×
2854
        }
×
2855

UNCOV
2856
        return isClosed, nil
×
2857
}
2858

2859
// GraphSession will provide the call-back with access to a NodeTraverser
2860
// instance which can be used to perform queries against the channel graph.
2861
//
2862
// NOTE: part of the V1Store interface.
2863
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
UNCOV
2864
        reset func()) error {
×
2865

×
2866
        var ctx = context.TODO()
×
2867

×
2868
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2869
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2870
        }, reset)
×
2871
}
2872

2873
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2874
// read only transaction for a consistent view of the graph.
2875
type sqlNodeTraverser struct {
2876
        db    SQLQueries
2877
        chain chainhash.Hash
2878
}
2879

2880
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2881
// NodeTraverser interface.
2882
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2883

2884
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2885
func newSQLNodeTraverser(db SQLQueries,
UNCOV
2886
        chain chainhash.Hash) *sqlNodeTraverser {
×
2887

×
2888
        return &sqlNodeTraverser{
×
2889
                db:    db,
×
2890
                chain: chain,
×
2891
        }
×
2892
}
×
2893

2894
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2895
// node.
2896
//
2897
// NOTE: Part of the NodeTraverser interface.
2898
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
UNCOV
2899
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2900

×
2901
        ctx := context.TODO()
×
2902

×
2903
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2904
}
×
2905

2906
// FetchNodeFeatures returns the features of the given node. If the node is
2907
// unknown, assume no additional features are supported.
2908
//
2909
// NOTE: Part of the NodeTraverser interface.
2910
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
UNCOV
2911
        *lnwire.FeatureVector, error) {
×
2912

×
2913
        ctx := context.TODO()
×
2914

×
2915
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2916
}
×
2917

2918
// forEachNodeDirectedChannel iterates through all channels of a given
2919
// node, executing the passed callback on the directed edge representing the
2920
// channel and its incoming policy. If the node is not found, no error is
2921
// returned.
2922
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
UNCOV
2923
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2924

×
2925
        toNodeCallback := func() route.Vertex {
×
2926
                return nodePub
×
2927
        }
×
2928

UNCOV
2929
        dbID, err := db.GetNodeIDByPubKey(
×
2930
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2931
                        Version: int16(ProtocolV1),
×
2932
                        PubKey:  nodePub[:],
×
2933
                },
×
2934
        )
×
2935
        if errors.Is(err, sql.ErrNoRows) {
×
2936
                return nil
×
2937
        } else if err != nil {
×
2938
                return fmt.Errorf("unable to fetch node: %w", err)
×
2939
        }
×
2940

UNCOV
2941
        rows, err := db.ListChannelsByNodeID(
×
2942
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2943
                        Version: int16(ProtocolV1),
×
2944
                        NodeID1: dbID,
×
2945
                },
×
2946
        )
×
2947
        if err != nil {
×
2948
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2949
        }
×
2950

2951
        // Exit early if there are no channels for this node so we don't
2952
        // do the unnecessary feature fetching.
UNCOV
2953
        if len(rows) == 0 {
×
2954
                return nil
×
2955
        }
×
2956

UNCOV
2957
        features, err := getNodeFeatures(ctx, db, dbID)
×
2958
        if err != nil {
×
2959
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2960
        }
×
2961

UNCOV
2962
        for _, row := range rows {
×
2963
                node1, node2, err := buildNodeVertices(
×
2964
                        row.Node1Pubkey, row.Node2Pubkey,
×
2965
                )
×
2966
                if err != nil {
×
2967
                        return fmt.Errorf("unable to build node vertices: %w",
×
2968
                                err)
×
2969
                }
×
2970

UNCOV
2971
                edge := buildCacheableChannelInfo(
×
2972
                        row.GraphChannel, node1, node2,
×
2973
                )
×
2974

×
2975
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2976
                if err != nil {
×
2977
                        return err
×
2978
                }
×
2979

UNCOV
2980
                var p1, p2 *models.CachedEdgePolicy
×
2981
                if dbPol1 != nil {
×
2982
                        policy1, err := buildChanPolicy(
×
2983
                                *dbPol1, edge.ChannelID, nil, node2,
×
2984
                        )
×
2985
                        if err != nil {
×
2986
                                return err
×
2987
                        }
×
2988

UNCOV
2989
                        p1 = models.NewCachedPolicy(policy1)
×
2990
                }
UNCOV
2991
                if dbPol2 != nil {
×
2992
                        policy2, err := buildChanPolicy(
×
2993
                                *dbPol2, edge.ChannelID, nil, node1,
×
2994
                        )
×
2995
                        if err != nil {
×
2996
                                return err
×
2997
                        }
×
2998

UNCOV
2999
                        p2 = models.NewCachedPolicy(policy2)
×
3000
                }
3001

3002
                // Determine the outgoing and incoming policy for this
3003
                // channel and node combo.
UNCOV
3004
                outPolicy, inPolicy := p1, p2
×
3005
                if p1 != nil && node2 == nodePub {
×
3006
                        outPolicy, inPolicy = p2, p1
×
3007
                } else if p2 != nil && node1 != nodePub {
×
3008
                        outPolicy, inPolicy = p2, p1
×
3009
                }
×
3010

UNCOV
3011
                var cachedInPolicy *models.CachedEdgePolicy
×
3012
                if inPolicy != nil {
×
3013
                        cachedInPolicy = inPolicy
×
3014
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3015
                        cachedInPolicy.ToNodeFeatures = features
×
3016
                }
×
3017

UNCOV
3018
                directedChannel := &DirectedChannel{
×
3019
                        ChannelID:    edge.ChannelID,
×
3020
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3021
                        OtherNode:    edge.NodeKey2Bytes,
×
3022
                        Capacity:     edge.Capacity,
×
3023
                        OutPolicySet: outPolicy != nil,
×
3024
                        InPolicy:     cachedInPolicy,
×
3025
                }
×
3026
                if outPolicy != nil {
×
3027
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3028
                                directedChannel.InboundFee = fee
×
3029
                        })
×
3030
                }
3031

UNCOV
3032
                if nodePub == edge.NodeKey2Bytes {
×
3033
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3034
                }
×
3035

UNCOV
3036
                if err := cb(directedChannel); err != nil {
×
3037
                        return err
×
3038
                }
×
3039
        }
3040

UNCOV
3041
        return nil
×
3042
}
3043

3044
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3045
// and executes the provided callback for each node.
3046
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
UNCOV
3047
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3048

×
3049
        lastID := int64(-1)
×
3050

×
3051
        for {
×
3052
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3053
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3054
                                Version: int16(ProtocolV1),
×
3055
                                ID:      lastID,
×
3056
                                Limit:   pageSize,
×
3057
                        },
×
3058
                )
×
3059
                if err != nil {
×
3060
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3061
                }
×
3062

UNCOV
3063
                if len(nodes) == 0 {
×
3064
                        break
×
3065
                }
3066

UNCOV
3067
                for _, node := range nodes {
×
3068
                        var pub route.Vertex
×
3069
                        copy(pub[:], node.PubKey)
×
3070

×
3071
                        if err := cb(node.ID, pub); err != nil {
×
3072
                                return fmt.Errorf("forEachNodeCacheable "+
×
3073
                                        "callback failed for node(id=%d): %w",
×
3074
                                        node.ID, err)
×
3075
                        }
×
3076

UNCOV
3077
                        lastID = node.ID
×
3078
                }
3079
        }
3080

UNCOV
3081
        return nil
×
3082
}
3083

3084
// forEachNodeChannel iterates through all channels of a node, executing
3085
// the passed callback on each. The call-back is provided with the channel's
3086
// edge information, the outgoing policy and the incoming policy for the
3087
// channel and node combo.
3088
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3089
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3090
                *models.ChannelEdgePolicy,
UNCOV
3091
                *models.ChannelEdgePolicy) error) error {
×
3092

×
3093
        // Get all the V1 channels for this node.Add commentMore actions
×
3094
        rows, err := db.ListChannelsByNodeID(
×
3095
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3096
                        Version: int16(ProtocolV1),
×
3097
                        NodeID1: id,
×
3098
                },
×
3099
        )
×
3100
        if err != nil {
×
3101
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3102
        }
×
3103

3104
        // Call the call-back for each channel and its known policies.
UNCOV
3105
        for _, row := range rows {
×
3106
                node1, node2, err := buildNodeVertices(
×
3107
                        row.Node1Pubkey, row.Node2Pubkey,
×
3108
                )
×
3109
                if err != nil {
×
3110
                        return fmt.Errorf("unable to build node vertices: %w",
×
3111
                                err)
×
3112
                }
×
3113

UNCOV
3114
                edge, err := getAndBuildEdgeInfo(
×
3115
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
3116
                        node1, node2,
×
3117
                )
×
3118
                if err != nil {
×
3119
                        return fmt.Errorf("unable to build channel info: %w",
×
3120
                                err)
×
3121
                }
×
3122

UNCOV
3123
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3124
                if err != nil {
×
3125
                        return fmt.Errorf("unable to extract channel "+
×
3126
                                "policies: %w", err)
×
3127
                }
×
3128

UNCOV
3129
                p1, p2, err := getAndBuildChanPolicies(
×
3130
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3131
                )
×
3132
                if err != nil {
×
3133
                        return fmt.Errorf("unable to build channel "+
×
3134
                                "policies: %w", err)
×
3135
                }
×
3136

3137
                // Determine the outgoing and incoming policy for this
3138
                // channel and node combo.
UNCOV
3139
                p1ToNode := row.GraphChannel.NodeID2
×
3140
                p2ToNode := row.GraphChannel.NodeID1
×
3141
                outPolicy, inPolicy := p1, p2
×
3142
                if (p1 != nil && p1ToNode == id) ||
×
3143
                        (p2 != nil && p2ToNode != id) {
×
3144

×
3145
                        outPolicy, inPolicy = p2, p1
×
3146
                }
×
3147

UNCOV
3148
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3149
                        return err
×
3150
                }
×
3151
        }
3152

UNCOV
3153
        return nil
×
3154
}
3155

3156
// updateChanEdgePolicy upserts the channel policy info we have stored for
3157
// a channel we already know of.
3158
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3159
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
UNCOV
3160
        error) {
×
3161

×
3162
        var (
×
3163
                node1Pub, node2Pub route.Vertex
×
3164
                isNode1            bool
×
3165
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3166
        )
×
3167

×
3168
        // Check that this edge policy refers to a channel that we already
×
3169
        // know of. We do this explicitly so that we can return the appropriate
×
3170
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3171
        // abort the transaction which would abort the entire batch.
×
3172
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3173
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3174
                        Scid:    chanIDB,
×
3175
                        Version: int16(ProtocolV1),
×
3176
                },
×
3177
        )
×
3178
        if errors.Is(err, sql.ErrNoRows) {
×
3179
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3180
        } else if err != nil {
×
3181
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3182
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3183
        }
×
3184

UNCOV
3185
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3186
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3187

×
3188
        // Figure out which node this edge is from.
×
3189
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3190
        nodeID := dbChan.NodeID1
×
3191
        if !isNode1 {
×
3192
                nodeID = dbChan.NodeID2
×
3193
        }
×
3194

UNCOV
3195
        var (
×
3196
                inboundBase sql.NullInt64
×
3197
                inboundRate sql.NullInt64
×
3198
        )
×
3199
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3200
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3201
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3202
        })
×
3203

UNCOV
3204
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3205
                Version:     int16(ProtocolV1),
×
3206
                ChannelID:   dbChan.ID,
×
3207
                NodeID:      nodeID,
×
3208
                Timelock:    int32(edge.TimeLockDelta),
×
3209
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3210
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3211
                MinHtlcMsat: int64(edge.MinHTLC),
×
3212
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3213
                Disabled: sql.NullBool{
×
3214
                        Valid: true,
×
3215
                        Bool:  edge.IsDisabled(),
×
3216
                },
×
3217
                MaxHtlcMsat: sql.NullInt64{
×
3218
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3219
                        Int64: int64(edge.MaxHTLC),
×
3220
                },
×
3221
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3222
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3223
                InboundBaseFeeMsat:      inboundBase,
×
3224
                InboundFeeRateMilliMsat: inboundRate,
×
3225
                Signature:               edge.SigBytes,
×
3226
        })
×
3227
        if err != nil {
×
3228
                return node1Pub, node2Pub, isNode1,
×
3229
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3230
        }
×
3231

3232
        // Convert the flat extra opaque data into a map of TLV types to
3233
        // values.
UNCOV
3234
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3235
        if err != nil {
×
3236
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3237
                        "marshal extra opaque data: %w", err)
×
3238
        }
×
3239

3240
        // Update the channel policy's extra signed fields.
UNCOV
3241
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3242
        if err != nil {
×
3243
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3244
                        "policy extra TLVs: %w", err)
×
3245
        }
×
3246

UNCOV
3247
        return node1Pub, node2Pub, isNode1, nil
×
3248
}
3249

3250
// getNodeByPubKey attempts to look up a target node by its public key.
3251
func getNodeByPubKey(ctx context.Context, db SQLQueries,
UNCOV
3252
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3253

×
3254
        dbNode, err := db.GetNodeByPubKey(
×
3255
                ctx, sqlc.GetNodeByPubKeyParams{
×
3256
                        Version: int16(ProtocolV1),
×
3257
                        PubKey:  pubKey[:],
×
3258
                },
×
3259
        )
×
3260
        if errors.Is(err, sql.ErrNoRows) {
×
3261
                return 0, nil, ErrGraphNodeNotFound
×
3262
        } else if err != nil {
×
3263
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3264
        }
×
3265

UNCOV
3266
        node, err := buildNode(ctx, db, &dbNode)
×
3267
        if err != nil {
×
3268
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3269
        }
×
3270

UNCOV
3271
        return dbNode.ID, node, nil
×
3272
}
3273

3274
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3275
// provided database channel row and the public keys of the two nodes
3276
// involved in the channel.
3277
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
UNCOV
3278
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3279

×
3280
        return &models.CachedEdgeInfo{
×
3281
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3282
                NodeKey1Bytes: node1Pub,
×
3283
                NodeKey2Bytes: node2Pub,
×
3284
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3285
        }
×
3286
}
×
3287

3288
// buildNode constructs a LightningNode instance from the given database node
3289
// record. The node's features, addresses and extra signed fields are also
3290
// fetched from the database and set on the node.
3291
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
UNCOV
3292
        *models.LightningNode, error) {
×
3293

×
3294
        if dbNode.Version != int16(ProtocolV1) {
×
3295
                return nil, fmt.Errorf("unsupported node version: %d",
×
3296
                        dbNode.Version)
×
3297
        }
×
3298

UNCOV
3299
        var pub [33]byte
×
3300
        copy(pub[:], dbNode.PubKey)
×
3301

×
3302
        node := &models.LightningNode{
×
3303
                PubKeyBytes: pub,
×
3304
                Features:    lnwire.EmptyFeatureVector(),
×
3305
                LastUpdate:  time.Unix(0, 0),
×
3306
        }
×
3307

×
3308
        if len(dbNode.Signature) == 0 {
×
3309
                return node, nil
×
3310
        }
×
3311

UNCOV
3312
        node.HaveNodeAnnouncement = true
×
3313
        node.AuthSigBytes = dbNode.Signature
×
3314
        node.Alias = dbNode.Alias.String
×
3315
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3316

×
3317
        var err error
×
3318
        if dbNode.Color.Valid {
×
3319
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3320
                if err != nil {
×
3321
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3322
                                err)
×
3323
                }
×
3324
        }
3325

3326
        // Fetch the node's features.
UNCOV
3327
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3328
        if err != nil {
×
3329
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3330
                        "features: %w", dbNode.ID, err)
×
3331
        }
×
3332

3333
        // Fetch the node's addresses.
UNCOV
3334
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3335
        if err != nil {
×
3336
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3337
                        "addresses: %w", dbNode.ID, err)
×
3338
        }
×
3339

3340
        // Fetch the node's extra signed fields.
UNCOV
3341
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3342
        if err != nil {
×
3343
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3344
                        "extra signed fields: %w", dbNode.ID, err)
×
3345
        }
×
3346

UNCOV
3347
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3348
        if err != nil {
×
3349
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3350
                        "fields: %w", err)
×
3351
        }
×
3352

UNCOV
3353
        if len(recs) != 0 {
×
3354
                node.ExtraOpaqueData = recs
×
3355
        }
×
3356

UNCOV
3357
        return node, nil
×
3358
}
3359

3360
// getNodeFeatures fetches the feature bits and constructs the feature vector
3361
// for a node with the given DB ID.
3362
func getNodeFeatures(ctx context.Context, db SQLQueries,
UNCOV
3363
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3364

×
3365
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3366
        if err != nil {
×
3367
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3368
                        nodeID, err)
×
3369
        }
×
3370

UNCOV
3371
        features := lnwire.EmptyFeatureVector()
×
3372
        for _, feature := range rows {
×
3373
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3374
        }
×
3375

UNCOV
3376
        return features, nil
×
3377
}
3378

3379
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3380
// given DB ID.
3381
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
3382
        nodeID int64) (map[uint64][]byte, error) {
×
3383

×
3384
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3385
        if err != nil {
×
3386
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3387
                        "signed fields: %w", nodeID, err)
×
3388
        }
×
3389

UNCOV
3390
        extraFields := make(map[uint64][]byte)
×
3391
        for _, field := range fields {
×
3392
                extraFields[uint64(field.Type)] = field.Value
×
3393
        }
×
3394

UNCOV
3395
        return extraFields, nil
×
3396
}
3397

3398
// upsertNode upserts the node record into the database. If the node already
3399
// exists, then the node's information is updated. If the node doesn't exist,
3400
// then a new node is created. The node's features, addresses and extra TLV
3401
// types are also updated. The node's DB ID is returned.
3402
func upsertNode(ctx context.Context, db SQLQueries,
UNCOV
3403
        node *models.LightningNode) (int64, error) {
×
3404

×
3405
        params := sqlc.UpsertNodeParams{
×
3406
                Version: int16(ProtocolV1),
×
3407
                PubKey:  node.PubKeyBytes[:],
×
3408
        }
×
3409

×
3410
        if node.HaveNodeAnnouncement {
×
3411
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3412
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3413
                params.Alias = sqldb.SQLStr(node.Alias)
×
3414
                params.Signature = node.AuthSigBytes
×
3415
        }
×
3416

UNCOV
3417
        nodeID, err := db.UpsertNode(ctx, params)
×
3418
        if err != nil {
×
3419
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3420
                        err)
×
3421
        }
×
3422

3423
        // We can exit here if we don't have the announcement yet.
UNCOV
3424
        if !node.HaveNodeAnnouncement {
×
3425
                return nodeID, nil
×
3426
        }
×
3427

3428
        // Update the node's features.
UNCOV
3429
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3430
        if err != nil {
×
3431
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3432
        }
×
3433

3434
        // Update the node's addresses.
UNCOV
3435
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3436
        if err != nil {
×
3437
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3438
        }
×
3439

3440
        // Convert the flat extra opaque data into a map of TLV types to
3441
        // values.
UNCOV
3442
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3443
        if err != nil {
×
3444
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3445
                        err)
×
3446
        }
×
3447

3448
        // Update the node's extra signed fields.
UNCOV
3449
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3450
        if err != nil {
×
3451
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3452
        }
×
3453

UNCOV
3454
        return nodeID, nil
×
3455
}
3456

3457
// upsertNodeFeatures updates the node's features node_features table. This
3458
// includes deleting any feature bits no longer present and inserting any new
3459
// feature bits. If the feature bit does not yet exist in the features table,
3460
// then an entry is created in that table first.
3461
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
UNCOV
3462
        features *lnwire.FeatureVector) error {
×
3463

×
3464
        // Get any existing features for the node.
×
3465
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3466
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3467
                return err
×
3468
        }
×
3469

3470
        // Copy the nodes latest set of feature bits.
UNCOV
3471
        newFeatures := make(map[int32]struct{})
×
3472
        if features != nil {
×
3473
                for feature := range features.Features() {
×
3474
                        newFeatures[int32(feature)] = struct{}{}
×
3475
                }
×
3476
        }
3477

3478
        // For any current feature that already exists in the DB, remove it from
3479
        // the in-memory map. For any existing feature that does not exist in
3480
        // the in-memory map, delete it from the database.
UNCOV
3481
        for _, feature := range existingFeatures {
×
3482
                // The feature is still present, so there are no updates to be
×
3483
                // made.
×
3484
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3485
                        delete(newFeatures, feature.FeatureBit)
×
3486
                        continue
×
3487
                }
3488

3489
                // The feature is no longer present, so we remove it from the
3490
                // database.
UNCOV
3491
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3492
                        NodeID:     nodeID,
×
3493
                        FeatureBit: feature.FeatureBit,
×
3494
                })
×
3495
                if err != nil {
×
3496
                        return fmt.Errorf("unable to delete node(%d) "+
×
3497
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3498
                                err)
×
3499
                }
×
3500
        }
3501

3502
        // Any remaining entries in newFeatures are new features that need to be
3503
        // added to the database for the first time.
UNCOV
3504
        for feature := range newFeatures {
×
3505
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3506
                        NodeID:     nodeID,
×
3507
                        FeatureBit: feature,
×
3508
                })
×
3509
                if err != nil {
×
3510
                        return fmt.Errorf("unable to insert node(%d) "+
×
3511
                                "feature(%v): %w", nodeID, feature, err)
×
3512
                }
×
3513
        }
3514

UNCOV
3515
        return nil
×
3516
}
3517

3518
// fetchNodeFeatures fetches the features for a node with the given public key.
3519
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
UNCOV
3520
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3521

×
3522
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3523
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3524
                        PubKey:  nodePub[:],
×
3525
                        Version: int16(ProtocolV1),
×
3526
                },
×
3527
        )
×
3528
        if err != nil {
×
3529
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3530
                        nodePub, err)
×
3531
        }
×
3532

UNCOV
3533
        features := lnwire.EmptyFeatureVector()
×
3534
        for _, bit := range rows {
×
3535
                features.Set(lnwire.FeatureBit(bit))
×
3536
        }
×
3537

UNCOV
3538
        return features, nil
×
3539
}
3540

3541
// dbAddressType enumerates the kinds of address kept in the node_addresses
// table. The kind stored alongside an address decides how that address is
// serialised and deserialised.
type dbAddressType uint8

// The numeric values are written to and read back from the database as the
// address type, so they are spelled out explicitly rather than derived.
const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 = dbAddressType(1)

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 = dbAddressType(2)

	// addressTypeTorV2 marks a v2 onion service address.
	addressTypeTorV2 = dbAddressType(3)

	// addressTypeTorV3 marks a v3 onion service address.
	addressTypeTorV3 = dbAddressType(4)

	// addressTypeOpaque marks an address that is stored as an opaque
	// payload.
	addressTypeOpaque = dbAddressType(math.MaxInt8)
)
3553

3554
// upsertNodeAddresses updates the node's addresses in the database. This
3555
// includes deleting any existing addresses and inserting the new set of
3556
// addresses. The deletion is necessary since the ordering of the addresses may
3557
// change, and we need to ensure that the database reflects the latest set of
3558
// addresses so that at the time of reconstructing the node announcement, the
3559
// order is preserved and the signature over the message remains valid.
3560
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
UNCOV
3561
        addresses []net.Addr) error {
×
3562

×
3563
        // Delete any existing addresses for the node. This is required since
×
3564
        // even if the new set of addresses is the same, the ordering may have
×
3565
        // changed for a given address type.
×
3566
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3567
        if err != nil {
×
3568
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3569
                        nodeID, err)
×
3570
        }
×
3571

3572
        // Copy the nodes latest set of addresses.
UNCOV
3573
        newAddresses := map[dbAddressType][]string{
×
3574
                addressTypeIPv4:   {},
×
3575
                addressTypeIPv6:   {},
×
3576
                addressTypeTorV2:  {},
×
3577
                addressTypeTorV3:  {},
×
3578
                addressTypeOpaque: {},
×
3579
        }
×
3580
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3581
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3582
        }
×
3583

UNCOV
3584
        for _, address := range addresses {
×
3585
                switch addr := address.(type) {
×
3586
                case *net.TCPAddr:
×
3587
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3588
                                addAddr(addressTypeIPv4, addr)
×
3589
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3590
                                addAddr(addressTypeIPv6, addr)
×
3591
                        } else {
×
3592
                                return fmt.Errorf("unhandled IP address: %v",
×
3593
                                        addr)
×
3594
                        }
×
3595

UNCOV
3596
                case *tor.OnionAddr:
×
3597
                        switch len(addr.OnionService) {
×
3598
                        case tor.V2Len:
×
3599
                                addAddr(addressTypeTorV2, addr)
×
3600
                        case tor.V3Len:
×
3601
                                addAddr(addressTypeTorV3, addr)
×
3602
                        default:
×
3603
                                return fmt.Errorf("invalid length for a tor " +
×
3604
                                        "address")
×
3605
                        }
3606

UNCOV
3607
                case *lnwire.OpaqueAddrs:
×
3608
                        addAddr(addressTypeOpaque, addr)
×
3609

UNCOV
3610
                default:
×
3611
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3612
                }
3613
        }
3614

3615
        // Any remaining entries in newAddresses are new addresses that need to
3616
        // be added to the database for the first time.
UNCOV
3617
        for addrType, addrList := range newAddresses {
×
3618
                for position, addr := range addrList {
×
3619
                        err := db.InsertNodeAddress(
×
3620
                                ctx, sqlc.InsertNodeAddressParams{
×
3621
                                        NodeID:   nodeID,
×
3622
                                        Type:     int16(addrType),
×
3623
                                        Address:  addr,
×
3624
                                        Position: int32(position),
×
3625
                                },
×
3626
                        )
×
3627
                        if err != nil {
×
3628
                                return fmt.Errorf("unable to insert "+
×
3629
                                        "node(%d) address(%v): %w", nodeID,
×
3630
                                        addr, err)
×
3631
                        }
×
3632
                }
3633
        }
3634

UNCOV
3635
        return nil
×
3636
}
3637

3638
// getNodeAddresses fetches the addresses for a node with the given public key.
3639
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
UNCOV
3640
        []net.Addr, error) {
×
3641

×
3642
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3643
        // are returned in the same order as they were inserted.
×
3644
        rows, err := db.GetNodeAddressesByPubKey(
×
3645
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3646
                        Version: int16(ProtocolV1),
×
3647
                        PubKey:  nodePub,
×
3648
                },
×
3649
        )
×
3650
        if err != nil {
×
3651
                return false, nil, err
×
3652
        }
×
3653

3654
        // GetNodeAddressesByPubKey uses a left join so there should always be
3655
        // at least one row returned if the node exists even if it has no
3656
        // addresses.
UNCOV
3657
        if len(rows) == 0 {
×
3658
                return false, nil, nil
×
3659
        }
×
3660

UNCOV
3661
        addresses := make([]net.Addr, 0, len(rows))
×
3662
        for _, addr := range rows {
×
3663
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3664
                        continue
×
3665
                }
3666

UNCOV
3667
                address := addr.Address.String
×
3668

×
3669
                switch dbAddressType(addr.Type.Int16) {
×
3670
                case addressTypeIPv4:
×
3671
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3672
                        if err != nil {
×
3673
                                return false, nil, nil
×
3674
                        }
×
3675
                        tcp.IP = tcp.IP.To4()
×
3676

×
3677
                        addresses = append(addresses, tcp)
×
3678

UNCOV
3679
                case addressTypeIPv6:
×
3680
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3681
                        if err != nil {
×
3682
                                return false, nil, nil
×
3683
                        }
×
3684
                        addresses = append(addresses, tcp)
×
3685

UNCOV
3686
                case addressTypeTorV3, addressTypeTorV2:
×
3687
                        service, portStr, err := net.SplitHostPort(address)
×
3688
                        if err != nil {
×
3689
                                return false, nil, fmt.Errorf("unable to "+
×
3690
                                        "split tor v3 address: %v",
×
3691
                                        addr.Address)
×
3692
                        }
×
3693

UNCOV
3694
                        port, err := strconv.Atoi(portStr)
×
3695
                        if err != nil {
×
3696
                                return false, nil, err
×
3697
                        }
×
3698

UNCOV
3699
                        addresses = append(addresses, &tor.OnionAddr{
×
3700
                                OnionService: service,
×
3701
                                Port:         port,
×
3702
                        })
×
3703

UNCOV
3704
                case addressTypeOpaque:
×
3705
                        opaque, err := hex.DecodeString(address)
×
3706
                        if err != nil {
×
3707
                                return false, nil, fmt.Errorf("unable to "+
×
3708
                                        "decode opaque address: %v", addr)
×
3709
                        }
×
3710

UNCOV
3711
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3712
                                Payload: opaque,
×
3713
                        })
×
3714

UNCOV
3715
                default:
×
3716
                        return false, nil, fmt.Errorf("unknown address "+
×
3717
                                "type: %v", addr.Type)
×
3718
                }
3719
        }
3720

3721
        // If we have no addresses, then we'll return nil instead of an
3722
        // empty slice.
UNCOV
3723
        if len(addresses) == 0 {
×
3724
                addresses = nil
×
3725
        }
×
3726

UNCOV
3727
        return true, addresses, nil
×
3728
}
3729

3730
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3731
// database. This includes updating any existing types, inserting any new types,
3732
// and deleting any types that are no longer present.
3733
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
3734
        nodeID int64, extraFields map[uint64][]byte) error {
×
3735

×
3736
        // Get any existing extra signed fields for the node.
×
3737
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3738
        if err != nil {
×
3739
                return err
×
3740
        }
×
3741

3742
        // Make a lookup map of the existing field types so that we can use it
3743
        // to keep track of any fields we should delete.
UNCOV
3744
        m := make(map[uint64]bool)
×
3745
        for _, field := range existingFields {
×
3746
                m[uint64(field.Type)] = true
×
3747
        }
×
3748

3749
        // For all the new fields, we'll upsert them and remove them from the
3750
        // map of existing fields.
UNCOV
3751
        for tlvType, value := range extraFields {
×
3752
                err = db.UpsertNodeExtraType(
×
3753
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3754
                                NodeID: nodeID,
×
3755
                                Type:   int64(tlvType),
×
3756
                                Value:  value,
×
3757
                        },
×
3758
                )
×
3759
                if err != nil {
×
3760
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3761
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3762
                }
×
3763

3764
                // Remove the field from the map of existing fields if it was
3765
                // present.
UNCOV
3766
                delete(m, tlvType)
×
3767
        }
3768

3769
        // For all the fields that are left in the map of existing fields, we'll
3770
        // delete them as they are no longer present in the new set of fields.
UNCOV
3771
        for tlvType := range m {
×
3772
                err = db.DeleteExtraNodeType(
×
3773
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3774
                                NodeID: nodeID,
×
3775
                                Type:   int64(tlvType),
×
3776
                        },
×
3777
                )
×
3778
                if err != nil {
×
3779
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3780
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3781
                }
×
3782
        }
3783

UNCOV
3784
        return nil
×
3785
}
3786

3787
// srcNodeInfo holds the information about the source node of the graph.
3788
type srcNodeInfo struct {
3789
        // id is the DB level ID of the source node entry in the "nodes" table.
3790
        id int64
3791

3792
        // pub is the public key of the source node.
3793
        pub route.Vertex
3794
}
3795

3796
// sourceNode returns the DB node ID and pub key of the source node for the
3797
// specified protocol version.
3798
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
UNCOV
3799
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3800

×
3801
        s.srcNodeMu.Lock()
×
3802
        defer s.srcNodeMu.Unlock()
×
3803

×
3804
        // If we already have the source node ID and pub key cached, then
×
3805
        // return them.
×
3806
        if info, ok := s.srcNodes[version]; ok {
×
3807
                return info.id, info.pub, nil
×
3808
        }
×
3809

UNCOV
3810
        var pubKey route.Vertex
×
3811

×
3812
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3813
        if err != nil {
×
3814
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3815
                        err)
×
3816
        }
×
3817

UNCOV
3818
        if len(nodes) == 0 {
×
3819
                return 0, pubKey, ErrSourceNodeNotSet
×
3820
        } else if len(nodes) > 1 {
×
3821
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3822
                        "protocol %s found", version)
×
3823
        }
×
3824

UNCOV
3825
        copy(pubKey[:], nodes[0].PubKey)
×
3826

×
3827
        s.srcNodes[version] = &srcNodeInfo{
×
3828
                id:  nodes[0].NodeID,
×
3829
                pub: pubKey,
×
3830
        }
×
3831

×
3832
        return nodes[0].NodeID, pubKey, nil
×
3833
}
3834

3835
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3836
// This then produces a map from TLV type to value. If the input is not a
3837
// valid TLV stream, then an error is returned.
UNCOV
3838
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3839
        r := bytes.NewReader(data)
×
3840

×
3841
        tlvStream, err := tlv.NewStream()
×
3842
        if err != nil {
×
3843
                return nil, err
×
3844
        }
×
3845

3846
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3847
        // pass it into the P2P decoding variant.
UNCOV
3848
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3849
        if err != nil {
×
3850
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3851
        }
×
3852
        if len(parsedTypes) == 0 {
×
3853
                return nil, nil
×
3854
        }
×
3855

UNCOV
3856
        records := make(map[uint64][]byte)
×
3857
        for k, v := range parsedTypes {
×
3858
                records[uint64(k)] = v
×
3859
        }
×
3860

UNCOV
3861
        return records, nil
×
3862
}
3863

3864
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3865
// channel.
3866
type dbChanInfo struct {
3867
        channelID int64
3868
        node1ID   int64
3869
        node2ID   int64
3870
}
3871

3872
// insertChannel inserts a new channel record into the database.
3873
func insertChannel(ctx context.Context, db SQLQueries,
UNCOV
3874
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3875

×
3876
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3877

×
3878
        // Make sure that the channel doesn't already exist. We do this
×
3879
        // explicitly instead of relying on catching a unique constraint error
×
3880
        // because relying on SQL to throw that error would abort the entire
×
3881
        // batch of transactions.
×
3882
        _, err := db.GetChannelBySCID(
×
3883
                ctx, sqlc.GetChannelBySCIDParams{
×
3884
                        Scid:    chanIDB,
×
3885
                        Version: int16(ProtocolV1),
×
3886
                },
×
3887
        )
×
3888
        if err == nil {
×
3889
                return nil, ErrEdgeAlreadyExist
×
3890
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3891
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3892
        }
×
3893

3894
        // Make sure that at least a "shell" entry for each node is present in
3895
        // the nodes table.
UNCOV
3896
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3897
        if err != nil {
×
3898
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3899
        }
×
3900

UNCOV
3901
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3902
        if err != nil {
×
3903
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3904
        }
×
3905

UNCOV
3906
        var capacity sql.NullInt64
×
3907
        if edge.Capacity != 0 {
×
3908
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3909
        }
×
3910

UNCOV
3911
        createParams := sqlc.CreateChannelParams{
×
3912
                Version:     int16(ProtocolV1),
×
3913
                Scid:        chanIDB,
×
3914
                NodeID1:     node1DBID,
×
3915
                NodeID2:     node2DBID,
×
3916
                Outpoint:    edge.ChannelPoint.String(),
×
3917
                Capacity:    capacity,
×
3918
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3919
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3920
        }
×
3921

×
3922
        if edge.AuthProof != nil {
×
3923
                proof := edge.AuthProof
×
3924

×
3925
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3926
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3927
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3928
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3929
        }
×
3930

3931
        // Insert the new channel record.
UNCOV
3932
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3933
        if err != nil {
×
3934
                return nil, err
×
3935
        }
×
3936

3937
        // Insert any channel features.
UNCOV
3938
        for feature := range edge.Features.Features() {
×
3939
                err = db.InsertChannelFeature(
×
3940
                        ctx, sqlc.InsertChannelFeatureParams{
×
3941
                                ChannelID:  dbChanID,
×
3942
                                FeatureBit: int32(feature),
×
3943
                        },
×
3944
                )
×
3945
                if err != nil {
×
3946
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3947
                                "feature(%v): %w", dbChanID, feature, err)
×
3948
                }
×
3949
        }
3950

3951
        // Finally, insert any extra TLV fields in the channel announcement.
UNCOV
3952
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3953
        if err != nil {
×
3954
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3955
                        "data: %w", err)
×
3956
        }
×
3957

UNCOV
3958
        for tlvType, value := range extra {
×
3959
                err := db.CreateChannelExtraType(
×
3960
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3961
                                ChannelID: dbChanID,
×
3962
                                Type:      int64(tlvType),
×
3963
                                Value:     value,
×
3964
                        },
×
3965
                )
×
3966
                if err != nil {
×
3967
                        return nil, fmt.Errorf("unable to upsert "+
×
3968
                                "channel(%d) extra signed field(%v): %w",
×
3969
                                edge.ChannelID, tlvType, err)
×
3970
                }
×
3971
        }
3972

UNCOV
3973
        return &dbChanInfo{
×
3974
                channelID: dbChanID,
×
3975
                node1ID:   node1DBID,
×
3976
                node2ID:   node2DBID,
×
3977
        }, nil
×
3978
}
3979

3980
// maybeCreateShellNode checks if a shell node entry exists for the
3981
// given public key. If it does not exist, then a new shell node entry is
3982
// created. The ID of the node is returned. A shell node only has a protocol
3983
// version and public key persisted.
3984
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
UNCOV
3985
        pubKey route.Vertex) (int64, error) {
×
3986

×
3987
        dbNode, err := db.GetNodeByPubKey(
×
3988
                ctx, sqlc.GetNodeByPubKeyParams{
×
3989
                        PubKey:  pubKey[:],
×
3990
                        Version: int16(ProtocolV1),
×
3991
                },
×
3992
        )
×
3993
        // The node exists. Return the ID.
×
3994
        if err == nil {
×
3995
                return dbNode.ID, nil
×
3996
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3997
                return 0, err
×
3998
        }
×
3999

4000
        // Otherwise, the node does not exist, so we create a shell entry for
4001
        // it.
UNCOV
4002
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4003
                Version: int16(ProtocolV1),
×
4004
                PubKey:  pubKey[:],
×
4005
        })
×
4006
        if err != nil {
×
4007
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4008
        }
×
4009

UNCOV
4010
        return id, nil
×
4011
}
4012

4013
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4014
// the database. This includes deleting any existing types and then inserting
4015
// the new types.
4016
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
4017
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4018

×
4019
        // Delete all existing extra signed fields for the channel policy.
×
4020
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4021
        if err != nil {
×
4022
                return fmt.Errorf("unable to delete "+
×
4023
                        "existing policy extra signed fields for policy %d: %w",
×
4024
                        chanPolicyID, err)
×
4025
        }
×
4026

4027
        // Insert all new extra signed fields for the channel policy.
UNCOV
4028
        for tlvType, value := range extraFields {
×
4029
                err = db.InsertChanPolicyExtraType(
×
4030
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4031
                                ChannelPolicyID: chanPolicyID,
×
4032
                                Type:            int64(tlvType),
×
4033
                                Value:           value,
×
4034
                        },
×
4035
                )
×
4036
                if err != nil {
×
4037
                        return fmt.Errorf("unable to insert "+
×
4038
                                "channel_policy(%d) extra signed field(%v): %w",
×
4039
                                chanPolicyID, tlvType, err)
×
4040
                }
×
4041
        }
4042

UNCOV
4043
        return nil
×
4044
}
4045

4046
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4047
// provided dbChanRow and also fetches any other required information
4048
// to construct the edge info.
4049
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4050
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
UNCOV
4051
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4052

×
4053
        if dbChan.Version != int16(ProtocolV1) {
×
4054
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4055
                        dbChan.Version)
×
4056
        }
×
4057

UNCOV
4058
        fv, extras, err := getChanFeaturesAndExtras(
×
4059
                ctx, db, dbChanID,
×
4060
        )
×
4061
        if err != nil {
×
4062
                return nil, err
×
4063
        }
×
4064

UNCOV
4065
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4066
        if err != nil {
×
4067
                return nil, err
×
4068
        }
×
4069

UNCOV
4070
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4071
        if err != nil {
×
4072
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4073
                        "fields: %w", err)
×
4074
        }
×
4075
        if recs == nil {
×
4076
                recs = make([]byte, 0)
×
4077
        }
×
4078

UNCOV
4079
        var btcKey1, btcKey2 route.Vertex
×
4080
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4081
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4082

×
4083
        channel := &models.ChannelEdgeInfo{
×
4084
                ChainHash:        chain,
×
4085
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4086
                NodeKey1Bytes:    node1,
×
4087
                NodeKey2Bytes:    node2,
×
4088
                BitcoinKey1Bytes: btcKey1,
×
4089
                BitcoinKey2Bytes: btcKey2,
×
4090
                ChannelPoint:     *op,
×
4091
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4092
                Features:         fv,
×
4093
                ExtraOpaqueData:  recs,
×
4094
        }
×
4095

×
4096
        // We always set all the signatures at the same time, so we can
×
4097
        // safely check if one signature is present to determine if we have the
×
4098
        // rest of the signatures for the auth proof.
×
4099
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4100
                channel.AuthProof = &models.ChannelAuthProof{
×
4101
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4102
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4103
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4104
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4105
                }
×
4106
        }
×
4107

UNCOV
4108
        return channel, nil
×
4109
}
4110

4111
// buildNodeVertices is a helper that converts raw node public keys
4112
// into route.Vertex instances.
4113
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
UNCOV
4114
        route.Vertex, error) {
×
4115

×
4116
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4117
        if err != nil {
×
4118
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4119
                        "create vertex from node1 pubkey: %w", err)
×
4120
        }
×
4121

UNCOV
4122
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4123
        if err != nil {
×
4124
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4125
                        "create vertex from node2 pubkey: %w", err)
×
4126
        }
×
4127

UNCOV
4128
        return node1Vertex, node2Vertex, nil
×
4129
}
4130

4131
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4132
// for a channel with the given ID.
4133
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
UNCOV
4134
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4135

×
4136
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4137
        if err != nil {
×
4138
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4139
                        "features and extras: %w", err)
×
4140
        }
×
4141

UNCOV
4142
        var (
×
4143
                fv     = lnwire.EmptyFeatureVector()
×
4144
                extras = make(map[uint64][]byte)
×
4145
        )
×
4146
        for _, row := range rows {
×
4147
                if row.IsFeature {
×
4148
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4149

×
4150
                        continue
×
4151
                }
4152

UNCOV
4153
                tlvType, ok := row.ExtraKey.(int64)
×
4154
                if !ok {
×
4155
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4156
                                "TLV type: %T", row.ExtraKey)
×
4157
                }
×
4158

UNCOV
4159
                valueBytes, ok := row.Value.([]byte)
×
4160
                if !ok {
×
4161
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4162
                                "Value: %T", row.Value)
×
4163
                }
×
4164

UNCOV
4165
                extras[uint64(tlvType)] = valueBytes
×
4166
        }
4167

UNCOV
4168
        return fv, extras, nil
×
4169
}
4170

4171
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4172
// retrieves all the extra info required to build the complete
4173
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4174
// the provided sqlc.GraphChannelPolicy records are nil.
4175
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4176
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4177
        node2 route.Vertex) (*models.ChannelEdgePolicy,
UNCOV
4178
        *models.ChannelEdgePolicy, error) {
×
4179

×
4180
        if dbPol1 == nil && dbPol2 == nil {
×
4181
                return nil, nil, nil
×
4182
        }
×
4183

UNCOV
4184
        var (
×
4185
                policy1ID int64
×
4186
                policy2ID int64
×
4187
        )
×
4188
        if dbPol1 != nil {
×
4189
                policy1ID = dbPol1.ID
×
4190
        }
×
4191
        if dbPol2 != nil {
×
4192
                policy2ID = dbPol2.ID
×
4193
        }
×
4194
        rows, err := db.GetChannelPolicyExtraTypes(
×
4195
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4196
                        ID:   policy1ID,
×
4197
                        ID_2: policy2ID,
×
4198
                },
×
4199
        )
×
4200
        if err != nil {
×
4201
                return nil, nil, err
×
4202
        }
×
4203

UNCOV
4204
        var (
×
4205
                dbPol1Extras = make(map[uint64][]byte)
×
4206
                dbPol2Extras = make(map[uint64][]byte)
×
4207
        )
×
4208
        for _, row := range rows {
×
4209
                switch row.PolicyID {
×
4210
                case policy1ID:
×
4211
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4212
                case policy2ID:
×
4213
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4214
                default:
×
4215
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4216
                                "in row: %v", row.PolicyID, row)
×
4217
                }
4218
        }
4219

UNCOV
4220
        var pol1, pol2 *models.ChannelEdgePolicy
×
4221
        if dbPol1 != nil {
×
4222
                pol1, err = buildChanPolicy(
×
4223
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4224
                )
×
4225
                if err != nil {
×
4226
                        return nil, nil, err
×
4227
                }
×
4228
        }
UNCOV
4229
        if dbPol2 != nil {
×
4230
                pol2, err = buildChanPolicy(
×
4231
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4232
                )
×
4233
                if err != nil {
×
4234
                        return nil, nil, err
×
4235
                }
×
4236
        }
4237

UNCOV
4238
        return pol1, pol2, nil
×
4239
}
4240

4241
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4242
// provided sqlc.GraphChannelPolicy and other required information.
4243
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4244
        extras map[uint64][]byte,
UNCOV
4245
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4246

×
4247
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4248
        if err != nil {
×
4249
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4250
                        "fields: %w", err)
×
4251
        }
×
4252

UNCOV
4253
        var inboundFee fn.Option[lnwire.Fee]
×
4254
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4255
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4256

×
4257
                inboundFee = fn.Some(lnwire.Fee{
×
4258
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4259
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4260
                })
×
4261
        }
×
4262

UNCOV
4263
        return &models.ChannelEdgePolicy{
×
4264
                SigBytes:  dbPolicy.Signature,
×
4265
                ChannelID: channelID,
×
4266
                LastUpdate: time.Unix(
×
4267
                        dbPolicy.LastUpdate.Int64, 0,
×
4268
                ),
×
4269
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4270
                        dbPolicy.MessageFlags,
×
4271
                ),
×
4272
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4273
                        dbPolicy.ChannelFlags,
×
4274
                ),
×
4275
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4276
                MinHTLC: lnwire.MilliSatoshi(
×
4277
                        dbPolicy.MinHtlcMsat,
×
4278
                ),
×
4279
                MaxHTLC: lnwire.MilliSatoshi(
×
4280
                        dbPolicy.MaxHtlcMsat.Int64,
×
4281
                ),
×
4282
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4283
                        dbPolicy.BaseFeeMsat,
×
4284
                ),
×
4285
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4286
                ToNode:                    toNode,
×
4287
                InboundFee:                inboundFee,
×
4288
                ExtraOpaqueData:           recs,
×
4289
        }, nil
×
4290
}
4291

4292
// buildNodes builds the models.LightningNode instances for the
4293
// given row which is expected to be a sqlc type that contains node information.
4294
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4295
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
UNCOV
4296
        error) {
×
4297

×
4298
        node1, err := buildNode(ctx, db, &dbNode1)
×
4299
        if err != nil {
×
4300
                return nil, nil, err
×
4301
        }
×
4302

UNCOV
4303
        node2, err := buildNode(ctx, db, &dbNode2)
×
4304
        if err != nil {
×
4305
                return nil, nil, err
×
4306
        }
×
4307

UNCOV
4308
        return node1, node2, nil
×
4309
}
4310

4311
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row, which is expected to be one of the sqlc row types that carry
// channel policy information. It returns two policies, either of which may be
// nil if the corresponding policy ID column in the row is not valid (NULL).
// An error is returned if the row is not one of the supported types.
//
// Each supported row type exposes the same set of Policy1*/Policy2* fields,
// but they are distinct struct types, so every case repeats the same field
// mapping. The nullable columns that map onto non-nullable policy fields are
// unwrapped, while the remaining nullable columns are copied over as is.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
        *sqlc.GraphChannelPolicy, error) {

        var policy1, policy2 *sqlc.GraphChannelPolicy
        switch r := row.(type) {
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
                // A policy is only present if its ID column is valid.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelByOutpointWithPoliciesRow:
                // Same field mapping as above, for a different row type.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelBySCIDWithPoliciesRow:
                // Same field mapping as above, for a different row type.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
                // Same field mapping as above, for a different row type.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsByNodeIDRow:
                // Same field mapping as above, for a different row type.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsWithPoliciesPaginatedRow:
                // Same field mapping as above, for a different row type.
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil
        default:
                // Any row type not handled above is rejected, with the
                // concrete type included in the error for diagnosis.
                return nil, nil, fmt.Errorf("unexpected row type in "+
                        "extractChannelPolicies: %T", r)
        }
}
4590

4591
// channelIDToBytes converts a channel ID (SCID) to a byte array
4592
// representation.
UNCOV
4593
func channelIDToBytes(channelID uint64) []byte {
×
4594
        var chanIDB [8]byte
×
4595
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4596

×
4597
        return chanIDB[:]
×
4598
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc