• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16311823849

16 Jul 2025 06:10AM UTC coverage: 57.534% (-9.8%) from 67.321%
16311823849

Pull #10081

github

web-flow
Merge 59ccaf41e into 9059a4e7b
Pull Request #10081: graph/db: use `/*SLICE:<field_name>*/` to optimise various graph queries

0 of 136 new or added lines in 1 file covered. (0.0%)

28921 existing lines in 461 files now uncovered.

98645 of 171456 relevant lines covered (57.53%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// value is a starting point and may be adjusted once benchmarks are in.
const pageSize = 2_000
37

38
// ProtocolVersion enumerates the revisions of the gossip protocol that a
// message may belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the version as the letter "V" followed by its decimal
// value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.FormatUint(uint64(v), 10)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannel(ctx context.Context, id int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115

116
        /*
117
                Channel Policy table queries.
118
        */
119
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
120
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
121
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
122

123
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
124
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
125
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
126

127
        /*
128
                Zombie index queries.
129
        */
130
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
131
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
132
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
133
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
134
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
135

136
        /*
137
                Prune log table queries.
138
        */
139
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
140
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
141
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
142
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
143

144
        /*
145
                Closed SCID table queries.
146
        */
147
        InsertClosedChannel(ctx context.Context, scid []byte) error
148
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
149
}
150

151
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
152
// database operations.
153
type BatchedSQLQueries interface {
154
        SQLQueries
155
        sqldb.BatchedTx[SQLQueries]
156
}
157

158
// SQLStore is an implementation of the V1Store interface that uses a SQL
159
// database as the backend.
160
type SQLStore struct {
161
        cfg *SQLStoreConfig
162
        db  BatchedSQLQueries
163

164
        // cacheMu guards all caches (rejectCache and chanCache). If
165
        // this mutex will be acquired at the same time as the DB mutex then
166
        // the cacheMu MUST be acquired first to prevent deadlock.
167
        cacheMu     sync.RWMutex
168
        rejectCache *rejectCache
169
        chanCache   *channelCache
170

171
        chanScheduler batch.Scheduler[SQLQueries]
172
        nodeScheduler batch.Scheduler[SQLQueries]
173

174
        srcNodes  map[ProtocolVersion]*srcNodeInfo
175
        srcNodeMu sync.Mutex
176
}
177

178
// A compile-time assertion to ensure that SQLStore implements the V1Store
179
// interface.
180
var _ V1Store = (*SQLStore)(nil)
181

182
// SQLStoreConfig holds the configuration for the SQLStore.
183
type SQLStoreConfig struct {
184
        // ChainHash is the genesis hash for the chain that all the gossip
185
        // messages in this store are aimed at.
186
        ChainHash chainhash.Hash
187

188
        // PaginationCfg is the configuration for paginated queries.
189
        PaginationCfg *sqldb.PagedQueryConfig
190
}
191

192
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
193
// storage backend.
194
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
195
        options ...StoreOptionModifier) (*SQLStore, error) {
×
196

×
197
        opts := DefaultOptions()
×
198
        for _, o := range options {
×
199
                o(opts)
×
200
        }
×
201

202
        if opts.NoMigration {
×
203
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
204
                        "supported for SQL stores")
×
205
        }
×
206

207
        s := &SQLStore{
×
208
                cfg:         cfg,
×
209
                db:          db,
×
210
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
211
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
212
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
213
        }
×
214

×
215
        s.chanScheduler = batch.NewTimeScheduler(
×
216
                db, &s.cacheMu, opts.BatchCommitInterval,
×
217
        )
×
218
        s.nodeScheduler = batch.NewTimeScheduler(
×
219
                db, nil, opts.BatchCommitInterval,
×
220
        )
×
221

×
222
        return s, nil
×
223
}
224

225
// AddLightningNode adds a vertex/node to the graph database. If the node is not
226
// in the database from before, this will add a new, unconnected one to the
227
// graph. If it is present from before, this will update that node's
228
// information.
229
//
230
// NOTE: part of the V1Store interface.
231
func (s *SQLStore) AddLightningNode(ctx context.Context,
232
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
233

×
234
        r := &batch.Request[SQLQueries]{
×
235
                Opts: batch.NewSchedulerOptions(opts...),
×
236
                Do: func(queries SQLQueries) error {
×
237
                        _, err := upsertNode(ctx, queries, node)
×
238
                        return err
×
239
                },
×
240
        }
241

242
        return s.nodeScheduler.Execute(ctx, r)
×
243
}
244

245
// FetchLightningNode attempts to look up a target node by its identity public
246
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
247
// returned.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) FetchLightningNode(ctx context.Context,
251
        pubKey route.Vertex) (*models.LightningNode, error) {
×
252

×
253
        var node *models.LightningNode
×
254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
255
                var err error
×
256
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
257

×
258
                return err
×
259
        }, sqldb.NoOpReset)
×
260
        if err != nil {
×
261
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
262
        }
×
263

264
        return node, nil
×
265
}
266

267
// HasLightningNode determines if the graph has a vertex identified by the
268
// target node identity public key. If the node exists in the database, a
269
// timestamp of when the data for the node was lasted updated is returned along
270
// with a true boolean. Otherwise, an empty time.Time is returned with a false
271
// boolean.
272
//
273
// NOTE: part of the V1Store interface.
274
func (s *SQLStore) HasLightningNode(ctx context.Context,
275
        pubKey [33]byte) (time.Time, bool, error) {
×
276

×
277
        var (
×
278
                exists     bool
×
279
                lastUpdate time.Time
×
280
        )
×
281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
282
                dbNode, err := db.GetNodeByPubKey(
×
283
                        ctx, sqlc.GetNodeByPubKeyParams{
×
284
                                Version: int16(ProtocolV1),
×
285
                                PubKey:  pubKey[:],
×
286
                        },
×
287
                )
×
288
                if errors.Is(err, sql.ErrNoRows) {
×
289
                        return nil
×
290
                } else if err != nil {
×
291
                        return fmt.Errorf("unable to fetch node: %w", err)
×
292
                }
×
293

294
                exists = true
×
295

×
296
                if dbNode.LastUpdate.Valid {
×
297
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return time.Time{}, false,
×
304
                        fmt.Errorf("unable to fetch node: %w", err)
×
305
        }
×
306

307
        return lastUpdate, exists, nil
×
308
}
309

310
// AddrsForNode returns all known addresses for the target node public key
311
// that the graph DB is aware of. The returned boolean indicates if the
312
// given node is unknown to the graph DB or not.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) AddrsForNode(ctx context.Context,
316
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
317

×
318
        var (
×
319
                addresses []net.Addr
×
320
                known     bool
×
321
        )
×
322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
323
                var err error
×
324
                known, addresses, err = getNodeAddresses(
×
325
                        ctx, db, nodePub.SerializeCompressed(),
×
326
                )
×
327
                if err != nil {
×
328
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
329
                                err)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
336
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
337
        }
×
338

339
        return known, addresses, nil
×
340
}
341

342
// DeleteLightningNode starts a new database transaction to remove a vertex/node
343
// from the database according to the node's public key.
344
//
345
// NOTE: part of the V1Store interface.
346
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
347
        pubKey route.Vertex) error {
×
348

×
349
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
350
                res, err := db.DeleteNodeByPubKey(
×
351
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
352
                                Version: int16(ProtocolV1),
×
353
                                PubKey:  pubKey[:],
×
354
                        },
×
355
                )
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                rows, err := res.RowsAffected()
×
361
                if err != nil {
×
362
                        return err
×
363
                }
×
364

365
                if rows == 0 {
×
366
                        return ErrGraphNodeNotFound
×
367
                } else if rows > 1 {
×
368
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
369
                }
×
370

371
                return err
×
372
        }, sqldb.NoOpReset)
373
        if err != nil {
×
374
                return fmt.Errorf("unable to delete node: %w", err)
×
375
        }
×
376

377
        return nil
×
378
}
379

380
// FetchNodeFeatures returns the features of the given node. If no features are
381
// known for the node, an empty feature vector is returned.
382
//
383
// NOTE: this is part of the graphdb.NodeTraverser interface.
384
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
385
        *lnwire.FeatureVector, error) {
×
386

×
387
        ctx := context.TODO()
×
388

×
389
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
390
}
×
391

392
// DisabledChannelIDs returns the channel ids of disabled channels.
393
// A channel is disabled when two of the associated ChanelEdgePolicies
394
// have their disabled bit on.
395
//
396
// NOTE: part of the V1Store interface.
397
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
398
        var (
×
399
                ctx     = context.TODO()
×
400
                chanIDs []uint64
×
401
        )
×
402
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
403
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
404
                if err != nil {
×
405
                        return fmt.Errorf("unable to fetch disabled "+
×
406
                                "channels: %w", err)
×
407
                }
×
408

409
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
410

×
411
                return nil
×
412
        }, sqldb.NoOpReset)
413
        if err != nil {
×
414
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
415
                        err)
×
416
        }
×
417

418
        return chanIDs, nil
×
419
}
420

421
// LookupAlias attempts to return the alias as advertised by the target node.
422
//
423
// NOTE: part of the V1Store interface.
424
func (s *SQLStore) LookupAlias(ctx context.Context,
425
        pub *btcec.PublicKey) (string, error) {
×
426

×
427
        var alias string
×
428
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
429
                dbNode, err := db.GetNodeByPubKey(
×
430
                        ctx, sqlc.GetNodeByPubKeyParams{
×
431
                                Version: int16(ProtocolV1),
×
432
                                PubKey:  pub.SerializeCompressed(),
×
433
                        },
×
434
                )
×
435
                if errors.Is(err, sql.ErrNoRows) {
×
436
                        return ErrNodeAliasNotFound
×
437
                } else if err != nil {
×
438
                        return fmt.Errorf("unable to fetch node: %w", err)
×
439
                }
×
440

441
                if !dbNode.Alias.Valid {
×
442
                        return ErrNodeAliasNotFound
×
443
                }
×
444

445
                alias = dbNode.Alias.String
×
446

×
447
                return nil
×
448
        }, sqldb.NoOpReset)
449
        if err != nil {
×
450
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
451
        }
×
452

453
        return alias, nil
×
454
}
455

456
// SourceNode returns the source node of the graph. The source node is treated
457
// as the center node within a star-graph. This method may be used to kick off
458
// a path finding algorithm in order to explore the reachability of another
459
// node based off the source node.
460
//
461
// NOTE: part of the V1Store interface.
462
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
463
        error) {
×
464

×
465
        var node *models.LightningNode
×
466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
467
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
468
                if err != nil {
×
469
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
470
                                err)
×
471
                }
×
472

473
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
474

×
475
                return err
×
476
        }, sqldb.NoOpReset)
477
        if err != nil {
×
478
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
479
        }
×
480

481
        return node, nil
×
482
}
483

484
// SetSourceNode sets the source node within the graph database. The source
485
// node is to be used as the center of a star-graph within path finding
486
// algorithms.
487
//
488
// NOTE: part of the V1Store interface.
489
func (s *SQLStore) SetSourceNode(ctx context.Context,
490
        node *models.LightningNode) error {
×
491

×
492
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
493
                id, err := upsertNode(ctx, db, node)
×
494
                if err != nil {
×
495
                        return fmt.Errorf("unable to upsert source node: %w",
×
496
                                err)
×
497
                }
×
498

499
                // Make sure that if a source node for this version is already
500
                // set, then the ID is the same as the one we are about to set.
501
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
502
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
503
                        return fmt.Errorf("unable to fetch source node: %w",
×
504
                                err)
×
505
                } else if err == nil {
×
506
                        if dbSourceNodeID != id {
×
507
                                return fmt.Errorf("v1 source node already "+
×
508
                                        "set to a different node: %d vs %d",
×
509
                                        dbSourceNodeID, id)
×
510
                        }
×
511

512
                        return nil
×
513
                }
514

515
                return db.AddSourceNode(ctx, id)
×
516
        }, sqldb.NoOpReset)
517
}
518

519
// NodeUpdatesInHorizon returns all the known lightning node which have an
520
// update timestamp within the passed range. This method can be used by two
521
// nodes to quickly determine if they have the same set of up to date node
522
// announcements.
523
//
524
// NOTE: This is part of the V1Store interface.
525
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
526
        endTime time.Time) ([]models.LightningNode, error) {
×
527

×
528
        ctx := context.TODO()
×
529

×
530
        var nodes []models.LightningNode
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
533
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
534
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
535
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
536
                        },
×
537
                )
×
538
                if err != nil {
×
539
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
540
                }
×
541

542
                for _, dbNode := range dbNodes {
×
543
                        node, err := buildNode(ctx, db, &dbNode)
×
544
                        if err != nil {
×
545
                                return fmt.Errorf("unable to build node: %w",
×
546
                                        err)
×
547
                        }
×
548

549
                        nodes = append(nodes, *node)
×
550
                }
551

552
                return nil
×
553
        }, sqldb.NoOpReset)
554
        if err != nil {
×
555
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
556
        }
×
557

558
        return nodes, nil
×
559
}
560

561
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
562
// undirected edge from the two target nodes are created. The information stored
563
// denotes the static attributes of the channel, such as the channelID, the keys
564
// involved in creation of the channel, and the set of features that the channel
565
// supports. The chanPoint and chanID are used to uniquely identify the edge
566
// globally within the database.
567
//
568
// NOTE: part of the V1Store interface.
569
func (s *SQLStore) AddChannelEdge(ctx context.Context,
570
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
571

×
572
        var alreadyExists bool
×
573
        r := &batch.Request[SQLQueries]{
×
574
                Opts: batch.NewSchedulerOptions(opts...),
×
575
                Reset: func() {
×
576
                        alreadyExists = false
×
577
                },
×
578
                Do: func(tx SQLQueries) error {
×
579
                        _, err := insertChannel(ctx, tx, edge)
×
580

×
581
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
582
                        // succeed, but propagate the error via local state.
×
583
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
584
                                alreadyExists = true
×
585
                                return nil
×
586
                        }
×
587

588
                        return err
×
589
                },
590
                OnCommit: func(err error) error {
×
591
                        switch {
×
592
                        case err != nil:
×
593
                                return err
×
594
                        case alreadyExists:
×
595
                                return ErrEdgeAlreadyExist
×
596
                        default:
×
597
                                s.rejectCache.remove(edge.ChannelID)
×
598
                                s.chanCache.remove(edge.ChannelID)
×
599
                                return nil
×
600
                        }
601
                },
602
        }
603

604
        return s.chanScheduler.Execute(ctx, r)
×
605
}
606

607
// HighestChanID returns the "highest" known channel ID in the channel graph.
608
// This represents the "newest" channel from the PoV of the chain. This method
609
// can be used by peers to quickly determine if their graphs are in sync.
610
//
611
// NOTE: This is part of the V1Store interface.
612
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
613
        var highestChanID uint64
×
614
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
616
                if errors.Is(err, sql.ErrNoRows) {
×
617
                        return nil
×
618
                } else if err != nil {
×
619
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
620
                                err)
×
621
                }
×
622

623
                highestChanID = byteOrder.Uint64(chanID)
×
624

×
625
                return nil
×
626
        }, sqldb.NoOpReset)
627
        if err != nil {
×
628
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
629
        }
×
630

631
        return highestChanID, nil
×
632
}
633

634
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
635
// within the database for the referenced channel. The `flags` attribute within
636
// the ChannelEdgePolicy determines which of the directed edges are being
637
// updated. If the flag is 1, then the first node's information is being
638
// updated, otherwise it's the second node's information. The node ordering is
639
// determined by the lexicographical ordering of the identity public keys of the
640
// nodes on either side of the channel.
641
//
642
// NOTE: part of the V1Store interface.
643
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
644
        edge *models.ChannelEdgePolicy,
645
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
646

×
647
        var (
×
648
                isUpdate1    bool
×
649
                edgeNotFound bool
×
650
                from, to     route.Vertex
×
651
        )
×
652

×
653
        r := &batch.Request[SQLQueries]{
×
654
                Opts: batch.NewSchedulerOptions(opts...),
×
655
                Reset: func() {
×
656
                        isUpdate1 = false
×
657
                        edgeNotFound = false
×
658
                },
×
659
                Do: func(tx SQLQueries) error {
×
660
                        var err error
×
661
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
662
                                ctx, tx, edge,
×
663
                        )
×
664
                        if err != nil {
×
665
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
666
                        }
×
667

668
                        // Silence ErrEdgeNotFound so that the batch can
669
                        // succeed, but propagate the error via local state.
670
                        if errors.Is(err, ErrEdgeNotFound) {
×
671
                                edgeNotFound = true
×
672
                                return nil
×
673
                        }
×
674

675
                        return err
×
676
                },
677
                OnCommit: func(err error) error {
×
678
                        switch {
×
679
                        case err != nil:
×
680
                                return err
×
681
                        case edgeNotFound:
×
682
                                return ErrEdgeNotFound
×
683
                        default:
×
684
                                s.updateEdgeCache(edge, isUpdate1)
×
685
                                return nil
×
686
                        }
687
                },
688
        }
689

690
        err := s.chanScheduler.Execute(ctx, r)
×
691

×
692
        return from, to, err
×
693
}
694

695
// updateEdgeCache updates our reject and channel caches with the new
696
// edge policy information.
697
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
698
        isUpdate1 bool) {
×
699

×
700
        // If an entry for this channel is found in reject cache, we'll modify
×
701
        // the entry with the updated timestamp for the direction that was just
×
702
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
703
        // during the next query for this edge.
×
704
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
705
                if isUpdate1 {
×
706
                        entry.upd1Time = e.LastUpdate.Unix()
×
707
                } else {
×
708
                        entry.upd2Time = e.LastUpdate.Unix()
×
709
                }
×
710
                s.rejectCache.insert(e.ChannelID, entry)
×
711
        }
712

713
        // If an entry for this channel is found in channel cache, we'll modify
714
        // the entry with the updated policy for the direction that was just
715
        // written. If the edge doesn't exist, we'll defer loading the info and
716
        // policies and lazily read from disk during the next query.
717
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
718
                if isUpdate1 {
×
719
                        channel.Policy1 = e
×
720
                } else {
×
721
                        channel.Policy2 = e
×
722
                }
×
723
                s.chanCache.insert(e.ChannelID, channel)
×
724
        }
725
}
726

727
// ForEachSourceNodeChannel iterates through all channels of the source node,
728
// executing the passed callback on each. The call-back is provided with the
729
// channel's outpoint, whether we have a policy for the channel and the channel
730
// peer's node information.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
734
        cb func(chanPoint wire.OutPoint, havePolicy bool,
735
        otherNode *models.LightningNode) error, reset func()) error {
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, reset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(ctx context.Context,
793
        cb func(tx NodeRTx) error, reset func()) error {
×
794

×
795
        var lastID int64 = 0
×
796
        handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
×
797
                node, err := buildNode(ctx, db, &dbNode)
×
798
                if err != nil {
×
799
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
800
                                dbNode.ID, err)
×
801
                }
×
802

803
                err = cb(
×
804
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
805
                )
×
806
                if err != nil {
×
807
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
808
                                dbNode.ID, err)
×
809
                }
×
810

811
                return nil
×
812
        }
813

814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
815
                for {
×
816
                        nodes, err := db.ListNodesPaginated(
×
817
                                ctx, sqlc.ListNodesPaginatedParams{
×
818
                                        Version: int16(ProtocolV1),
×
819
                                        ID:      lastID,
×
820
                                        Limit:   pageSize,
×
821
                                },
×
822
                        )
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("unable to fetch nodes: %w",
×
825
                                        err)
×
826
                        }
×
827

828
                        if len(nodes) == 0 {
×
829
                                break
×
830
                        }
831

832
                        for _, dbNode := range nodes {
×
833
                                err = handleNode(db, dbNode)
×
834
                                if err != nil {
×
835
                                        return err
×
836
                                }
×
837

838
                                lastID = dbNode.ID
×
839
                        }
840
                }
841

842
                return nil
×
843
        }, reset)
844
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error, reset func()) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, reset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
931
        cb func(route.Vertex, *lnwire.FeatureVector) error,
932
        reset func()) error {
×
933

×
934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
936
                        nodePub route.Vertex) error {
×
937

×
938
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
939
                        if err != nil {
×
940
                                return fmt.Errorf("unable to fetch node "+
×
941
                                        "features: %w", err)
×
942
                        }
×
943

944
                        return cb(nodePub, features)
×
945
                })
946
        }, reset)
947
        if err != nil {
×
948
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
949
        }
×
950

951
        return nil
×
952
}
953

954
// ForEachNodeChannel iterates through all channels of the given node,
955
// executing the passed callback with an edge info structure and the policies
956
// of each end of the channel. The first edge policy is the outgoing edge *to*
957
// the connecting node, while the second is the incoming edge *from* the
958
// connecting node. If the callback returns an error, then the iteration is
959
// halted with the error propagated back up to the caller.
960
//
961
// Unknown policies are passed into the callback as nil values.
962
//
963
// NOTE: part of the V1Store interface.
964
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
965
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
966
        *models.ChannelEdgePolicy) error, reset func()) error {
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, reset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1045
                                row.GraphChannel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1113
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1114
        reset func()) error {
×
1115

×
1116
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1117
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1118
                        nodePub route.Vertex) error {
×
1119

×
1120
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1121
                        if err != nil {
×
1122
                                return fmt.Errorf("unable to fetch "+
×
1123
                                        "node(id=%d) features: %w", nodeID, err)
×
1124
                        }
×
1125

1126
                        toNodeCallback := func() route.Vertex {
×
1127
                                return nodePub
×
1128
                        }
×
1129

1130
                        rows, err := db.ListChannelsByNodeID(
×
1131
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1132
                                        Version: int16(ProtocolV1),
×
1133
                                        NodeID1: nodeID,
×
1134
                                },
×
1135
                        )
×
1136
                        if err != nil {
×
1137
                                return fmt.Errorf("unable to fetch channels "+
×
1138
                                        "of node(id=%d): %w", nodeID, err)
×
1139
                        }
×
1140

1141
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1142
                        for _, row := range rows {
×
1143
                                node1, node2, err := buildNodeVertices(
×
1144
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1145
                                )
×
1146
                                if err != nil {
×
1147
                                        return err
×
1148
                                }
×
1149

1150
                                e, err := getAndBuildEdgeInfo(
×
1151
                                        ctx, db, s.cfg.ChainHash,
×
1152
                                        row.GraphChannel.ID, row.GraphChannel,
×
1153
                                        node1, node2,
×
1154
                                )
×
1155
                                if err != nil {
×
1156
                                        return fmt.Errorf("unable to build "+
×
1157
                                                "channel info: %w", err)
×
1158
                                }
×
1159

1160
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1161
                                        row,
×
1162
                                )
×
1163
                                if err != nil {
×
1164
                                        return fmt.Errorf("unable to "+
×
1165
                                                "extract channel "+
×
1166
                                                "policies: %w", err)
×
1167
                                }
×
1168

1169
                                p1, p2, err := getAndBuildChanPolicies(
×
1170
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1171
                                        node1, node2,
×
1172
                                )
×
1173
                                if err != nil {
×
1174
                                        return fmt.Errorf("unable to "+
×
1175
                                                "build channel policies: %w",
×
1176
                                                err)
×
1177
                                }
×
1178

1179
                                // Determine the outgoing and incoming policy
1180
                                // for this channel and node combo.
1181
                                outPolicy, inPolicy := p1, p2
×
1182
                                if p1 != nil && p1.ToNode == nodePub {
×
1183
                                        outPolicy, inPolicy = p2, p1
×
1184
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1185
                                        outPolicy, inPolicy = p2, p1
×
1186
                                }
×
1187

1188
                                var cachedInPolicy *models.CachedEdgePolicy
×
1189
                                if inPolicy != nil {
×
1190
                                        cachedInPolicy = models.NewCachedPolicy(
×
1191
                                                p2,
×
1192
                                        )
×
1193
                                        cachedInPolicy.ToNodePubKey =
×
1194
                                                toNodeCallback
×
1195
                                        cachedInPolicy.ToNodeFeatures =
×
1196
                                                features
×
1197
                                }
×
1198

1199
                                var inboundFee lnwire.Fee
×
1200
                                outPolicy.InboundFee.WhenSome(
×
1201
                                        func(fee lnwire.Fee) {
×
1202
                                                inboundFee = fee
×
1203
                                        },
×
1204
                                )
1205

1206
                                directedChannel := &DirectedChannel{
×
1207
                                        ChannelID: e.ChannelID,
×
1208
                                        IsNode1: nodePub ==
×
1209
                                                e.NodeKey1Bytes,
×
1210
                                        OtherNode:    e.NodeKey2Bytes,
×
1211
                                        Capacity:     e.Capacity,
×
1212
                                        OutPolicySet: p1 != nil,
×
1213
                                        InPolicy:     cachedInPolicy,
×
1214
                                        InboundFee:   inboundFee,
×
1215
                                }
×
1216

×
1217
                                if nodePub == e.NodeKey2Bytes {
×
1218
                                        directedChannel.OtherNode =
×
1219
                                                e.NodeKey1Bytes
×
1220
                                }
×
1221

1222
                                channels[e.ChannelID] = directedChannel
×
1223
                        }
1224

1225
                        return cb(nodePub, channels)
×
1226
                })
1227
        }, reset)
1228
}
1229

1230
// ForEachChannelCacheable iterates through all the channel edges stored
1231
// within the graph and invokes the passed callback for each edge. The
1232
// callback takes two edges as since this is a directed graph, both the
1233
// in/out edges are visited. If the callback returns an error, then the
1234
// transaction is aborted and the iteration stops early.
1235
//
1236
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1237
// pointer for that particular channel edge routing policy will be
1238
// passed into the callback.
1239
//
1240
// NOTE: this method is like ForEachChannel but fetches only the data
1241
// required for the graph cache.
1242
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1243
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1244
        reset func()) error {
×
1245

×
1246
        ctx := context.TODO()
×
1247

×
1248
        handleChannel := func(db SQLQueries,
×
1249
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1250

×
1251
                node1, node2, err := buildNodeVertices(
×
1252
                        row.Node1Pubkey, row.Node2Pubkey,
×
1253
                )
×
1254
                if err != nil {
×
1255
                        return err
×
1256
                }
×
1257

1258
                edge := buildCacheableChannelInfo(
×
1259
                        row.GraphChannel, node1, node2,
×
1260
                )
×
1261

×
1262
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1263
                if err != nil {
×
1264
                        return err
×
1265
                }
×
1266

1267
                var pol1, pol2 *models.CachedEdgePolicy
×
1268
                if dbPol1 != nil {
×
1269
                        policy1, err := buildChanPolicy(
×
1270
                                *dbPol1, edge.ChannelID, nil, node2,
×
1271
                        )
×
1272
                        if err != nil {
×
1273
                                return err
×
1274
                        }
×
1275

1276
                        pol1 = models.NewCachedPolicy(policy1)
×
1277
                }
1278
                if dbPol2 != nil {
×
1279
                        policy2, err := buildChanPolicy(
×
1280
                                *dbPol2, edge.ChannelID, nil, node1,
×
1281
                        )
×
1282
                        if err != nil {
×
1283
                                return err
×
1284
                        }
×
1285

1286
                        pol2 = models.NewCachedPolicy(policy2)
×
1287
                }
1288

1289
                if err := cb(edge, pol1, pol2); err != nil {
×
1290
                        return err
×
1291
                }
×
1292

1293
                return nil
×
1294
        }
1295

1296
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1297
                lastID := int64(-1)
×
1298
                for {
×
1299
                        //nolint:ll
×
1300
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1301
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1302
                                        Version: int16(ProtocolV1),
×
1303
                                        ID:      lastID,
×
1304
                                        Limit:   pageSize,
×
1305
                                },
×
1306
                        )
×
1307
                        if err != nil {
×
1308
                                return err
×
1309
                        }
×
1310

1311
                        if len(rows) == 0 {
×
1312
                                break
×
1313
                        }
1314

1315
                        for _, row := range rows {
×
1316
                                err := handleChannel(db, row)
×
1317
                                if err != nil {
×
1318
                                        return err
×
1319
                                }
×
1320

1321
                                lastID = row.GraphChannel.ID
×
1322
                        }
1323
                }
1324

1325
                return nil
×
1326
        }, reset)
1327
}
1328

1329
// ForEachChannel iterates through all the channel edges stored within the
1330
// graph and invokes the passed callback for each edge. The callback takes two
1331
// edges as since this is a directed graph, both the in/out edges are visited.
1332
// If the callback returns an error, then the transaction is aborted and the
1333
// iteration stops early.
1334
//
1335
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1336
// for that particular channel edge routing policy will be passed into the
1337
// callback.
1338
//
1339
// NOTE: part of the V1Store interface.
1340
func (s *SQLStore) ForEachChannel(ctx context.Context,
1341
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1342
        *models.ChannelEdgePolicy) error, reset func()) error {
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1357
                        row.GraphChannel, node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.GraphChannel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, reset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart,
×
1456
                                EndScid:   chanIDEnd,
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB,
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB,
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB,
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1746
                                row.GraphChannel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB,
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
                chanIDB          = channelIDToBytes(chanID)
×
1837
        )
×
1838
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1839
                row, err := db.GetChannelBySCIDWithPolicies(
×
1840
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1841
                                Scid:    chanIDB,
×
1842
                                Version: int16(ProtocolV1),
×
1843
                        },
×
1844
                )
×
1845
                if errors.Is(err, sql.ErrNoRows) {
×
1846
                        // First check if this edge is perhaps in the zombie
×
1847
                        // index.
×
1848
                        zombie, err := db.GetZombieChannel(
×
1849
                                ctx, sqlc.GetZombieChannelParams{
×
1850
                                        Scid:    chanIDB,
×
1851
                                        Version: int16(ProtocolV1),
×
1852
                                },
×
1853
                        )
×
1854
                        if errors.Is(err, sql.ErrNoRows) {
×
1855
                                return ErrEdgeNotFound
×
1856
                        } else if err != nil {
×
1857
                                return fmt.Errorf("unable to check if "+
×
1858
                                        "channel is zombie: %w", err)
×
1859
                        }
×
1860

1861
                        // At this point, we know the channel is a zombie, so
1862
                        // we'll return an error indicating this, and we will
1863
                        // populate the edge info with the public keys of each
1864
                        // party as this is the only information we have about
1865
                        // it.
1866
                        edge = &models.ChannelEdgeInfo{}
×
1867
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1868
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1869

×
1870
                        return ErrZombieEdge
×
1871
                } else if err != nil {
×
1872
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1873
                }
×
1874

1875
                node1, node2, err := buildNodeVertices(
×
1876
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1877
                )
×
1878
                if err != nil {
×
1879
                        return err
×
1880
                }
×
1881

1882
                edge, err = getAndBuildEdgeInfo(
×
1883
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1884
                        row.GraphChannel, node1, node2,
×
1885
                )
×
1886
                if err != nil {
×
1887
                        return fmt.Errorf("unable to build channel info: %w",
×
1888
                                err)
×
1889
                }
×
1890

1891
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1892
                if err != nil {
×
1893
                        return fmt.Errorf("unable to extract channel "+
×
1894
                                "policies: %w", err)
×
1895
                }
×
1896

1897
                policy1, policy2, err = getAndBuildChanPolicies(
×
1898
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1899
                )
×
1900
                if err != nil {
×
1901
                        return fmt.Errorf("unable to build channel "+
×
1902
                                "policies: %w", err)
×
1903
                }
×
1904

1905
                return nil
×
1906
        }, sqldb.NoOpReset)
1907
        if err != nil {
×
1908
                // If we are returning the ErrZombieEdge, then we also need to
×
1909
                // return the edge info as the method comment indicates that
×
1910
                // this will be populated when the edge is a zombie.
×
1911
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1912
                        err)
×
1913
        }
×
1914

1915
        return edge, policy1, policy2, nil
×
1916
}
1917

1918
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1919
// the channel identified by the funding outpoint. If the channel can't be
1920
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1921
// information for the channel itself is returned as well as two structs that
1922
// contain the routing policies for the channel in either direction.
1923
//
1924
// NOTE: part of the V1Store interface.
1925
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1926
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1927
        *models.ChannelEdgePolicy, error) {
×
1928

×
1929
        var (
×
1930
                ctx              = context.TODO()
×
1931
                edge             *models.ChannelEdgeInfo
×
1932
                policy1, policy2 *models.ChannelEdgePolicy
×
1933
        )
×
1934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1935
                row, err := db.GetChannelByOutpointWithPolicies(
×
1936
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1937
                                Outpoint: op.String(),
×
1938
                                Version:  int16(ProtocolV1),
×
1939
                        },
×
1940
                )
×
1941
                if errors.Is(err, sql.ErrNoRows) {
×
1942
                        return ErrEdgeNotFound
×
1943
                } else if err != nil {
×
1944
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1945
                }
×
1946

1947
                node1, node2, err := buildNodeVertices(
×
1948
                        row.Node1Pubkey, row.Node2Pubkey,
×
1949
                )
×
1950
                if err != nil {
×
1951
                        return err
×
1952
                }
×
1953

1954
                edge, err = getAndBuildEdgeInfo(
×
1955
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1956
                        row.GraphChannel, node1, node2,
×
1957
                )
×
1958
                if err != nil {
×
1959
                        return fmt.Errorf("unable to build channel info: %w",
×
1960
                                err)
×
1961
                }
×
1962

1963
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1964
                if err != nil {
×
1965
                        return fmt.Errorf("unable to extract channel "+
×
1966
                                "policies: %w", err)
×
1967
                }
×
1968

1969
                policy1, policy2, err = getAndBuildChanPolicies(
×
1970
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1971
                )
×
1972
                if err != nil {
×
1973
                        return fmt.Errorf("unable to build channel "+
×
1974
                                "policies: %w", err)
×
1975
                }
×
1976

1977
                return nil
×
1978
        }, sqldb.NoOpReset)
1979
        if err != nil {
×
1980
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1981
                        err)
×
1982
        }
×
1983

1984
        return edge, policy1, policy2, nil
×
1985
}
1986

1987
// HasChannelEdge returns true if the database knows of a channel edge with the
1988
// passed channel ID, and false otherwise. If an edge with that ID is found
1989
// within the graph, then two time stamps representing the last time the edge
1990
// was updated for both directed edges are returned along with the boolean. If
1991
// it is not found, then the zombie index is checked and its result is returned
1992
// as the second boolean.
1993
//
1994
// NOTE: part of the V1Store interface.
1995
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1996
        bool, error) {
×
1997

×
1998
        ctx := context.TODO()
×
1999

×
2000
        var (
×
2001
                exists          bool
×
2002
                isZombie        bool
×
2003
                node1LastUpdate time.Time
×
2004
                node2LastUpdate time.Time
×
2005
        )
×
2006

×
2007
        // We'll query the cache with the shared lock held to allow multiple
×
2008
        // readers to access values in the cache concurrently if they exist.
×
2009
        s.cacheMu.RLock()
×
2010
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2011
                s.cacheMu.RUnlock()
×
2012
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2013
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2014
                exists, isZombie = entry.flags.unpack()
×
2015

×
2016
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2017
        }
×
2018
        s.cacheMu.RUnlock()
×
2019

×
2020
        s.cacheMu.Lock()
×
2021
        defer s.cacheMu.Unlock()
×
2022

×
2023
        // The item was not found with the shared lock, so we'll acquire the
×
2024
        // exclusive lock and check the cache again in case another method added
×
2025
        // the entry to the cache while no lock was held.
×
2026
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2027
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2028
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2029
                exists, isZombie = entry.flags.unpack()
×
2030

×
2031
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2032
        }
×
2033

2034
        chanIDB := channelIDToBytes(chanID)
×
2035
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2036
                channel, err := db.GetChannelBySCID(
×
2037
                        ctx, sqlc.GetChannelBySCIDParams{
×
2038
                                Scid:    chanIDB,
×
2039
                                Version: int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        // Check if it is a zombie channel.
×
2044
                        isZombie, err = db.IsZombieChannel(
×
2045
                                ctx, sqlc.IsZombieChannelParams{
×
2046
                                        Scid:    chanIDB,
×
2047
                                        Version: int16(ProtocolV1),
×
2048
                                },
×
2049
                        )
×
2050
                        if err != nil {
×
2051
                                return fmt.Errorf("could not check if channel "+
×
2052
                                        "is zombie: %w", err)
×
2053
                        }
×
2054

2055
                        return nil
×
2056
                } else if err != nil {
×
2057
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2058
                }
×
2059

2060
                exists = true
×
2061

×
2062
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2063
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2064
                                Version:   int16(ProtocolV1),
×
2065
                                ChannelID: channel.ID,
×
2066
                                NodeID:    channel.NodeID1,
×
2067
                        },
×
2068
                )
×
2069
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2070
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2071
                                err)
×
2072
                } else if err == nil {
×
2073
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2074
                }
×
2075

2076
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2077
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2078
                                Version:   int16(ProtocolV1),
×
2079
                                ChannelID: channel.ID,
×
2080
                                NodeID:    channel.NodeID2,
×
2081
                        },
×
2082
                )
×
2083
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2084
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2085
                                err)
×
2086
                } else if err == nil {
×
2087
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2088
                }
×
2089

2090
                return nil
×
2091
        }, sqldb.NoOpReset)
2092
        if err != nil {
×
2093
                return time.Time{}, time.Time{}, false, false,
×
2094
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2095
        }
×
2096

2097
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2098
                upd1Time: node1LastUpdate.Unix(),
×
2099
                upd2Time: node2LastUpdate.Unix(),
×
2100
                flags:    packRejectFlags(exists, isZombie),
×
2101
        })
×
2102

×
2103
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2104
}
2105

2106
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2107
// passed channel point (outpoint). If the passed channel doesn't exist within
2108
// the database, then ErrEdgeNotFound is returned.
2109
//
2110
// NOTE: part of the V1Store interface.
2111
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2112
        var (
×
2113
                ctx       = context.TODO()
×
2114
                channelID uint64
×
2115
        )
×
2116
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2117
                chanID, err := db.GetSCIDByOutpoint(
×
2118
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2119
                                Outpoint: chanPoint.String(),
×
2120
                                Version:  int16(ProtocolV1),
×
2121
                        },
×
2122
                )
×
2123
                if errors.Is(err, sql.ErrNoRows) {
×
2124
                        return ErrEdgeNotFound
×
2125
                } else if err != nil {
×
2126
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2127
                                err)
×
2128
                }
×
2129

2130
                channelID = byteOrder.Uint64(chanID)
×
2131

×
2132
                return nil
×
2133
        }, sqldb.NoOpReset)
2134
        if err != nil {
×
2135
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2136
        }
×
2137

2138
        return channelID, nil
×
2139
}
2140

2141
// IsPublicNode is a helper method that determines whether the node with the
2142
// given public key is seen as a public node in the graph from the graph's
2143
// source node's point of view.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2147
        ctx := context.TODO()
×
2148

×
2149
        var isPublic bool
×
2150
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2151
                var err error
×
2152
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2153

×
2154
                return err
×
2155
        }, sqldb.NoOpReset)
×
2156
        if err != nil {
×
2157
                return false, fmt.Errorf("unable to check if node is "+
×
2158
                        "public: %w", err)
×
2159
        }
×
2160

2161
        return isPublic, nil
×
2162
}
2163

2164
// FetchChanInfos returns the set of channel edges that correspond to the passed
2165
// channel ID's. If an edge is the query is unknown to the database, it will
2166
// skipped and the result will contain only those edges that exist at the time
2167
// of the query. This can be used to respond to peer queries that are seeking to
2168
// fill in gaps in their view of the channel graph.
2169
//
2170
// NOTE: part of the V1Store interface.
2171
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2172
        var (
×
2173
                ctx   = context.TODO()
×
NEW
2174
                edges = make(map[uint64]ChannelEdge)
×
2175
        )
×
2176
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2177
                chanCallBack := func(ctx context.Context,
×
NEW
2178
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2179

×
2180
                        node1, node2, err := buildNodes(
×
2181
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2182
                        )
×
2183
                        if err != nil {
×
2184
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2185
                                        err)
×
2186
                        }
×
2187

2188
                        edge, err := getAndBuildEdgeInfo(
×
2189
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2190
                                row.GraphChannel, node1.PubKeyBytes,
×
2191
                                node2.PubKeyBytes,
×
2192
                        )
×
2193
                        if err != nil {
×
2194
                                return fmt.Errorf("unable to build "+
×
2195
                                        "channel info: %w", err)
×
2196
                        }
×
2197

2198
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to extract channel "+
×
2201
                                        "policies: %w", err)
×
2202
                        }
×
2203

2204
                        p1, p2, err := getAndBuildChanPolicies(
×
2205
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2206
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2207
                        )
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to build channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

NEW
2213
                        edges[edge.ChannelID] = ChannelEdge{
×
2214
                                Info:    edge,
×
2215
                                Policy1: p1,
×
2216
                                Policy2: p2,
×
2217
                                Node1:   node1,
×
2218
                                Node2:   node2,
×
NEW
2219
                        }
×
NEW
2220

×
NEW
2221
                        return nil
×
2222
                }
2223

NEW
2224
                queryWrapper := func(ctx context.Context, scids [][]byte) (
×
NEW
2225
                        []sqlc.GetChannelsBySCIDWithPoliciesRow, error) {
×
NEW
2226

×
2227
                        return db.GetChannelsBySCIDWithPolicies(
×
NEW
2228
                                ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2229
                                        Scids:   scids,
×
2230
                                        Version: int16(ProtocolV1),
×
2231
                                },
×
2232
                        )
×
UNCOV
2233
                }
×
2234

NEW
2235
                err := sqldb.ExecutePagedQuery(
×
NEW
2236
                        ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
NEW
2237
                        queryWrapper, chanCallBack,
×
NEW
2238
                )
×
NEW
2239
                if err != nil {
×
NEW
2240
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2241
                }
×
2242

NEW
2243
                return nil
×
NEW
2244
        }, func() {
×
NEW
2245
                clear(edges)
×
NEW
2246
        })
×
NEW
2247
        if err != nil {
×
NEW
2248
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2249
        }
×
2250

NEW
2251
        res := make([]ChannelEdge, 0, len(edges))
×
NEW
2252
        for _, chanID := range chanIDs {
×
NEW
2253
                edge, ok := edges[chanID]
×
NEW
2254
                if !ok {
×
NEW
2255
                        continue
×
2256
                }
2257

NEW
2258
                res = append(res, edge)
×
2259
        }
2260

NEW
2261
        return res, nil
×
2262
}
2263

2264
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2265
// ID's that we don't know and are not known zombies of the passed set. In other
2266
// words, we perform a set difference of our set of chan ID's and the ones
2267
// passed in. This method can be used by callers to determine the set of
2268
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2269
// known zombies is also returned.
2270
//
2271
// NOTE: part of the V1Store interface.
2272
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
UNCOV
2273
        []ChannelUpdateInfo, error) {
×
UNCOV
2274

×
UNCOV
2275
        var (
×
UNCOV
2276
                ctx          = context.TODO()
×
UNCOV
2277
                newChanIDs   []uint64
×
2278
                knownZombies []ChannelUpdateInfo
×
2279
                infoLookup   = make(
×
2280
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2281
                )
×
2282
        )
×
2283

×
NEW
2284
        for _, chanInfo := range chansInfo {
×
NEW
2285
                scid := chanInfo.ShortChannelID.ToUint64()
×
NEW
2286
                infoLookup[scid] = chanInfo
×
2287
        }
×
2288

NEW
2289
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2290
                queryWrapper := func(ctx context.Context,
×
NEW
2291
                        scids [][]byte) ([]sqlc.GraphChannel, error) {
×
NEW
2292

×
NEW
2293
                        return db.GetChannelsBySCIDs(
×
2294
                                ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2295
                                        Version: int16(ProtocolV1),
×
NEW
2296
                                        Scids:   scids,
×
NEW
2297
                                },
×
NEW
2298
                        )
×
NEW
2299
                }
×
2300

NEW
2301
                chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
NEW
2302
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2303

×
NEW
2304
                        return channelIDToBytes(channelID)
×
NEW
2305
                }
×
2306

NEW
2307
                cb := func(ctx context.Context,
×
2308
                        channel sqlc.GraphChannel) error {
×
2309

×
NEW
2310
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2311

×
2312
                        return nil
×
NEW
2313
                }
×
2314

2315
                err := sqldb.ExecutePagedQuery(
×
2316
                        ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
2317
                        queryWrapper, cb,
×
2318
                )
×
2319
                if err != nil {
×
2320
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2321
                }
×
2322

2323
                for _, chanInfo := range chansInfo {
×
2324
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
UNCOV
2325
                        if _, ok := infoLookup[channelID]; !ok {
×
2326
                                continue
×
2327
                        }
2328
                        chanIDB := channelIDToBytes(channelID)
×
2329

×
UNCOV
2330
                        isZombie, err := db.IsZombieChannel(
×
UNCOV
2331
                                ctx, sqlc.IsZombieChannelParams{
×
2332
                                        Scid:    chanIDB,
×
UNCOV
2333
                                        Version: int16(ProtocolV1),
×
UNCOV
2334
                                },
×
2335
                        )
×
2336
                        if err != nil {
×
2337
                                return fmt.Errorf("unable to fetch zombie "+
×
2338
                                        "channel: %w", err)
×
2339
                        }
×
2340

2341
                        if isZombie {
×
2342
                                knownZombies = append(knownZombies, chanInfo)
×
UNCOV
2343

×
2344
                                continue
×
2345
                        }
2346

NEW
2347
                        newChanIDs = append(newChanIDs, channelID)
×
2348
                }
2349

NEW
2350
                return nil
×
NEW
2351
        }, func() {
×
NEW
2352
                newChanIDs = nil
×
NEW
2353
                knownZombies = nil
×
NEW
2354
        })
×
NEW
2355
        if err != nil {
×
NEW
2356
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2357
        }
×
2358

NEW
2359
        return newChanIDs, knownZombies, nil
×
2360
}
2361

2362
// PruneGraphNodes is a garbage collection method which attempts to prune out
2363
// any nodes from the channel graph that are currently unconnected. This ensure
2364
// that we only maintain a graph of reachable nodes. In the event that a pruned
2365
// node gains more channels, it will be re-added back to the graph.
2366
//
2367
// NOTE: this prunes nodes across protocol versions. It will never prune the
2368
// source nodes.
2369
//
2370
// NOTE: part of the V1Store interface.
NEW
2371
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
NEW
2372
        var ctx = context.TODO()
×
NEW
2373

×
UNCOV
2374
        var prunedNodes []route.Vertex
×
UNCOV
2375
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
UNCOV
2376
                var err error
×
UNCOV
2377
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
UNCOV
2378

×
UNCOV
2379
                return err
×
UNCOV
2380
        }, func() {
×
UNCOV
2381
                prunedNodes = nil
×
UNCOV
2382
        })
×
2383
        if err != nil {
×
2384
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2385
        }
×
2386

2387
        return prunedNodes, nil
×
2388
}
2389

2390
// PruneGraph prunes newly closed channels from the channel graph in response
2391
// to a new block being solved on the network. Any transactions which spend the
2392
// funding output of any known channels within he graph will be deleted.
2393
// Additionally, the "prune tip", or the last block which has been used to
2394
// prune the graph is stored so callers can ensure the graph is fully in sync
2395
// with the current UTXO state. A slice of channels that have been closed by
2396
// the target block along with any pruned nodes are returned if the function
2397
// succeeds without error.
2398
//
2399
// NOTE: part of the V1Store interface.
2400
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2401
        blockHash *chainhash.Hash, blockHeight uint32) (
UNCOV
2402
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
UNCOV
2403

×
UNCOV
2404
        ctx := context.TODO()
×
UNCOV
2405

×
UNCOV
2406
        s.cacheMu.Lock()
×
UNCOV
2407
        defer s.cacheMu.Unlock()
×
UNCOV
2408

×
UNCOV
2409
        var (
×
UNCOV
2410
                closedChans []*models.ChannelEdgeInfo
×
UNCOV
2411
                prunedNodes []route.Vertex
×
UNCOV
2412
        )
×
UNCOV
2413
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2414
                // Define the callback function for processing each channel
×
2415
                channelCallback := func(ctx context.Context,
×
2416
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2417

×
2418
                        node1, node2, err := buildNodeVertices(
×
2419
                                row.Node1Pubkey, row.Node2Pubkey,
×
2420
                        )
×
2421
                        if err != nil {
×
2422
                                return err
×
2423
                        }
×
2424

2425
                        info, err := getAndBuildEdgeInfo(
×
NEW
2426
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
2427
                                row.GraphChannel, node1, node2,
×
NEW
2428
                        )
×
2429
                        if err != nil {
×
2430
                                return err
×
2431
                        }
×
2432

2433
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2434
                        if err != nil {
×
2435
                                return fmt.Errorf("unable to delete "+
×
UNCOV
2436
                                        "channel: %w", err)
×
2437
                        }
×
2438

2439
                        closedChans = append(closedChans, info)
×
2440
                        return nil
×
2441
                }
2442

2443
                err := s.forEachChanInOutpoints(
×
UNCOV
2444
                        ctx, db, spentOutputs, channelCallback,
×
2445
                )
×
2446
                if err != nil {
×
2447
                        return fmt.Errorf("unable to fetch channels by "+
×
2448
                                "outpoints: %w", err)
×
2449
                }
×
2450

2451
                err = db.UpsertPruneLogEntry(
×
NEW
2452
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
UNCOV
2453
                                BlockHash:   blockHash[:],
×
UNCOV
2454
                                BlockHeight: int64(blockHeight),
×
NEW
2455
                        },
×
NEW
2456
                )
×
NEW
2457
                if err != nil {
×
NEW
2458
                        return fmt.Errorf("unable to insert prune log "+
×
NEW
2459
                                "entry: %w", err)
×
NEW
2460
                }
×
2461

2462
                // Now that we've pruned some channels, we'll also prune any
2463
                // nodes that no longer have any channels.
2464
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2465
                if err != nil {
×
2466
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2467
                                err)
×
2468
                }
×
2469

2470
                return nil
×
2471
        }, func() {
×
2472
                prunedNodes = nil
×
UNCOV
2473
                closedChans = nil
×
UNCOV
2474
        })
×
UNCOV
2475
        if err != nil {
×
2476
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2477
        }
×
2478

2479
        for _, channel := range closedChans {
×
2480
                s.rejectCache.remove(channel.ChannelID)
×
UNCOV
2481
                s.chanCache.remove(channel.ChannelID)
×
2482
        }
×
2483

2484
        return closedChans, prunedNodes, nil
×
2485
}
2486

2487
// forEachChanInOutpoints is a helper function that executes a paginated
2488
// query to fetch channels by their outpoints and applies the given call-back
2489
// to each.
2490
//
2491
// NOTE: this fetches channels for all protocol versions.
2492
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2493
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2494
        row sqlc.GetChannelsByOutpointsRow) error) error {
×
UNCOV
2495

×
2496
        // Create a wrapper that uses the transaction's db instance to execute
×
UNCOV
2497
        // the query.
×
UNCOV
2498
        queryWrapper := func(ctx context.Context,
×
NEW
2499
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
NEW
2500
                error) {
×
NEW
2501

×
NEW
2502
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
NEW
2503
        }
×
2504

2505
        // Define the conversion function from Outpoint to string
NEW
2506
        outpointToString := func(outpoint *wire.OutPoint) string {
×
NEW
2507
                return outpoint.String()
×
NEW
2508
        }
×
2509

NEW
2510
        return sqldb.ExecutePagedQuery(
×
NEW
2511
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
NEW
2512
                queryWrapper, cb,
×
NEW
2513
        )
×
2514
}
2515

2516
// ChannelView returns the verifiable edge information for each active channel
2517
// within the known channel graph. The set of UTXOs (along with their scripts)
2518
// returned are the ones that need to be watched on chain to detect channel
2519
// closes on the resident blockchain.
2520
//
2521
// NOTE: part of the V1Store interface.
NEW
2522
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
NEW
2523
        var (
×
NEW
2524
                ctx        = context.TODO()
×
NEW
2525
                edgePoints []EdgePoint
×
NEW
2526
        )
×
NEW
2527

×
UNCOV
2528
        handleChannel := func(db SQLQueries,
×
UNCOV
2529
                channel sqlc.ListChannelsPaginatedRow) error {
×
UNCOV
2530

×
UNCOV
2531
                pkScript, err := genMultiSigP2WSH(
×
UNCOV
2532
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
UNCOV
2533
                )
×
2534
                if err != nil {
×
2535
                        return err
×
2536
                }
×
2537

2538
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2539
                if err != nil {
×
2540
                        return err
×
2541
                }
×
2542

2543
                edgePoints = append(edgePoints, EdgePoint{
×
2544
                        FundingPkScript: pkScript,
×
2545
                        OutPoint:        *op,
×
2546
                })
×
2547

×
2548
                return nil
×
2549
        }
2550

2551
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2552
                lastID := int64(-1)
×
2553
                for {
×
UNCOV
2554
                        rows, err := db.ListChannelsPaginated(
×
2555
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2556
                                        Version: int16(ProtocolV1),
×
2557
                                        ID:      lastID,
×
2558
                                        Limit:   pageSize,
×
2559
                                },
×
2560
                        )
×
UNCOV
2561
                        if err != nil {
×
UNCOV
2562
                                return err
×
2563
                        }
×
2564

2565
                        if len(rows) == 0 {
×
2566
                                break
×
2567
                        }
2568

2569
                        for _, row := range rows {
×
2570
                                err := handleChannel(db, row)
×
2571
                                if err != nil {
×
2572
                                        return err
×
2573
                                }
×
2574

2575
                                lastID = row.ID
×
2576
                        }
2577
                }
2578

UNCOV
2579
                return nil
×
UNCOV
2580
        }, func() {
×
2581
                edgePoints = nil
×
2582
        })
×
2583
        if err != nil {
×
2584
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2585
        }
×
2586

2587
        return edgePoints, nil
×
2588
}
2589

2590
// PruneTip returns the block height and hash of the latest block that has been
2591
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2592
// to tell if the graph is currently in sync with the current best known UTXO
2593
// state.
2594
//
2595
// NOTE: part of the V1Store interface.
2596
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2597
        var (
×
UNCOV
2598
                ctx       = context.TODO()
×
2599
                tipHash   chainhash.Hash
×
UNCOV
2600
                tipHeight uint32
×
UNCOV
2601
        )
×
UNCOV
2602
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
UNCOV
2603
                pruneTip, err := db.GetPruneTip(ctx)
×
UNCOV
2604
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
2605
                        return ErrGraphNeverPruned
×
UNCOV
2606
                } else if err != nil {
×
UNCOV
2607
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2608
                }
×
2609

2610
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2611
                tipHeight = uint32(pruneTip.BlockHeight)
×
2612

×
2613
                return nil
×
2614
        }, sqldb.NoOpReset)
2615
        if err != nil {
×
2616
                return nil, 0, err
×
2617
        }
×
2618

2619
        return &tipHash, tipHeight, nil
×
2620
}
2621

2622
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2623
//
2624
// NOTE: this prunes nodes across protocol versions. It will never prune the
2625
// source nodes.
2626
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2627
        db SQLQueries) ([]route.Vertex, error) {
×
2628

×
2629
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
UNCOV
2630
        if err != nil {
×
2631
                return nil, fmt.Errorf("unable to delete unconnected "+
×
UNCOV
2632
                        "nodes: %w", err)
×
UNCOV
2633
        }
×
2634

UNCOV
2635
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
UNCOV
2636
        for i, nodeKey := range nodeKeys {
×
UNCOV
2637
                pub, err := route.NewVertexFromBytes(nodeKey)
×
UNCOV
2638
                if err != nil {
×
2639
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2640
                                "from bytes: %w", err)
×
2641
                }
×
2642

2643
                prunedNodes[i] = pub
×
2644
        }
2645

UNCOV
2646
        return prunedNodes, nil
×
2647
}
2648

2649
// DisconnectBlockAtHeight is used to indicate that the block specified
2650
// by the passed height has been disconnected from the main chain. This
2651
// will "rewind" the graph back to the height below, deleting channels
2652
// that are no longer confirmed from the graph. The prune log will be
2653
// set to the last prune height valid for the remaining chain.
2654
// Channels that were removed from the graph resulting from the
2655
// disconnected block are returned.
2656
//
2657
// NOTE: part of the V1Store interface.
2658
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
UNCOV
2659
        []*models.ChannelEdgeInfo, error) {
×
UNCOV
2660

×
UNCOV
2661
        ctx := context.TODO()
×
UNCOV
2662

×
UNCOV
2663
        var (
×
UNCOV
2664
                // Every channel having a ShortChannelID starting at 'height'
×
UNCOV
2665
                // will no longer be confirmed.
×
UNCOV
2666
                startShortChanID = lnwire.ShortChannelID{
×
UNCOV
2667
                        BlockHeight: height,
×
UNCOV
2668
                }
×
UNCOV
2669

×
UNCOV
2670
                // Delete everything after this height from the db up until the
×
2671
                // SCID alias range.
×
2672
                endShortChanID = aliasmgr.StartingAlias
×
2673

×
2674
                removedChans []*models.ChannelEdgeInfo
×
2675

×
2676
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2677
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2678
        )
×
2679

×
2680
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2681
                rows, err := db.GetChannelsBySCIDRange(
×
2682
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2683
                                StartScid: chanIDStart,
×
2684
                                EndScid:   chanIDEnd,
×
2685
                        },
×
2686
                )
×
2687
                if err != nil {
×
2688
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2689
                }
×
2690

2691
                for _, row := range rows {
×
2692
                        node1, node2, err := buildNodeVertices(
×
2693
                                row.Node1PubKey, row.Node2PubKey,
×
2694
                        )
×
2695
                        if err != nil {
×
2696
                                return err
×
2697
                        }
×
2698

2699
                        channel, err := getAndBuildEdgeInfo(
×
2700
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2701
                                row.GraphChannel, node1, node2,
×
UNCOV
2702
                        )
×
2703
                        if err != nil {
×
2704
                                return err
×
2705
                        }
×
2706

2707
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2708
                        if err != nil {
×
2709
                                return fmt.Errorf("unable to delete "+
×
UNCOV
2710
                                        "channel: %w", err)
×
2711
                        }
×
2712

2713
                        removedChans = append(removedChans, channel)
×
2714
                }
2715

2716
                return db.DeletePruneLogEntriesInRange(
×
2717
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
UNCOV
2718
                                StartHeight: int64(height),
×
2719
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2720
                        },
×
2721
                )
×
2722
        }, func() {
×
2723
                removedChans = nil
×
UNCOV
2724
        })
×
2725
        if err != nil {
×
UNCOV
2726
                return nil, fmt.Errorf("unable to disconnect block at "+
×
UNCOV
2727
                        "height: %w", err)
×
2728
        }
×
2729

2730
        for _, channel := range removedChans {
×
2731
                s.rejectCache.remove(channel.ChannelID)
×
2732
                s.chanCache.remove(channel.ChannelID)
×
2733
        }
×
2734

2735
        return removedChans, nil
×
2736
}
2737

2738
// AddEdgeProof sets the proof of an existing edge in the graph database.
2739
//
2740
// NOTE: part of the V1Store interface.
2741
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2742
        proof *models.ChannelAuthProof) error {
×
2743

×
2744
        var (
×
2745
                ctx       = context.TODO()
×
UNCOV
2746
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2747
        )
×
UNCOV
2748

×
UNCOV
2749
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
UNCOV
2750
                res, err := db.AddV1ChannelProof(
×
UNCOV
2751
                        ctx, sqlc.AddV1ChannelProofParams{
×
UNCOV
2752
                                Scid:              scidBytes,
×
UNCOV
2753
                                Node1Signature:    proof.NodeSig1Bytes,
×
2754
                                Node2Signature:    proof.NodeSig2Bytes,
×
2755
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2756
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2757
                        },
×
2758
                )
×
2759
                if err != nil {
×
2760
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2761
                }
×
2762

2763
                n, err := res.RowsAffected()
×
2764
                if err != nil {
×
2765
                        return err
×
2766
                }
×
2767

2768
                if n == 0 {
×
2769
                        return fmt.Errorf("no rows affected when adding edge "+
×
2770
                                "proof for SCID %v", scid)
×
2771
                } else if n > 1 {
×
2772
                        return fmt.Errorf("multiple rows affected when adding "+
×
2773
                                "edge proof for SCID %v: %d rows affected",
×
UNCOV
2774
                                scid, n)
×
2775
                }
×
2776

2777
                return nil
×
2778
        }, sqldb.NoOpReset)
UNCOV
2779
        if err != nil {
×
2780
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2781
        }
×
2782

2783
        return nil
×
2784
}
2785

2786
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2787
// that we can ignore channel announcements that we know to be closed without
2788
// having to validate them and fetch a block.
2789
//
2790
// NOTE: part of the V1Store interface.
2791
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2792
        var (
×
2793
                ctx     = context.TODO()
×
UNCOV
2794
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2795
        )
×
UNCOV
2796

×
UNCOV
2797
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
UNCOV
2798
                return db.InsertClosedChannel(ctx, chanIDB)
×
UNCOV
2799
        }, sqldb.NoOpReset)
×
2800
}
2801

2802
// IsClosedScid checks whether a channel identified by the passed in scid is
2803
// closed. This helps avoid having to perform expensive validation checks.
2804
//
2805
// NOTE: part of the V1Store interface.
2806
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2807
        var (
×
2808
                ctx      = context.TODO()
×
2809
                isClosed bool
×
2810
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2811
        )
×
UNCOV
2812
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
2813
                var err error
×
UNCOV
2814
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
UNCOV
2815
                if err != nil {
×
UNCOV
2816
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
UNCOV
2817
                                err)
×
2818
                }
×
2819

2820
                return nil
×
2821
        }, sqldb.NoOpReset)
2822
        if err != nil {
×
2823
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2824
                        err)
×
2825
        }
×
2826

2827
        return isClosed, nil
×
2828
}
2829

2830
// GraphSession will provide the call-back with access to a NodeTraverser
2831
// instance which can be used to perform queries against the channel graph.
2832
//
2833
// NOTE: part of the V1Store interface.
2834
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2835
        reset func()) error {
×
2836

×
2837
        var ctx = context.TODO()
×
UNCOV
2838

×
2839
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
2840
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
UNCOV
2841
        }, reset)
×
2842
}
2843

2844
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2845
// read only transaction for a consistent view of the graph.
2846
type sqlNodeTraverser struct {
2847
        db    SQLQueries
2848
        chain chainhash.Hash
2849
}
2850

2851
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2852
// NodeTraverser interface.
2853
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2854

2855
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2856
func newSQLNodeTraverser(db SQLQueries,
UNCOV
2857
        chain chainhash.Hash) *sqlNodeTraverser {
×
UNCOV
2858

×
UNCOV
2859
        return &sqlNodeTraverser{
×
UNCOV
2860
                db:    db,
×
UNCOV
2861
                chain: chain,
×
UNCOV
2862
        }
×
UNCOV
2863
}
×
2864

2865
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2866
// node.
2867
//
2868
// NOTE: Part of the NodeTraverser interface.
2869
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2870
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2871

×
2872
        ctx := context.TODO()
×
2873

×
2874
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2875
}
×
2876

2877
// FetchNodeFeatures returns the features of the given node. If the node is
2878
// unknown, assume no additional features are supported.
2879
//
2880
// NOTE: Part of the NodeTraverser interface.
2881
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2882
        *lnwire.FeatureVector, error) {
×
2883

×
2884
        ctx := context.TODO()
×
2885

×
2886
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2887
}
×
2888

2889
// forEachNodeDirectedChannel iterates through all channels of a given
2890
// node, executing the passed callback on the directed edge representing the
2891
// channel and its incoming policy. If the node is not found, no error is
2892
// returned.
2893
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2894
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2895

×
2896
        toNodeCallback := func() route.Vertex {
×
2897
                return nodePub
×
2898
        }
×
2899

UNCOV
2900
        dbID, err := db.GetNodeIDByPubKey(
×
UNCOV
2901
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
UNCOV
2902
                        Version: int16(ProtocolV1),
×
UNCOV
2903
                        PubKey:  nodePub[:],
×
UNCOV
2904
                },
×
UNCOV
2905
        )
×
2906
        if errors.Is(err, sql.ErrNoRows) {
×
2907
                return nil
×
2908
        } else if err != nil {
×
2909
                return fmt.Errorf("unable to fetch node: %w", err)
×
2910
        }
×
2911

2912
        rows, err := db.ListChannelsByNodeID(
×
2913
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2914
                        Version: int16(ProtocolV1),
×
2915
                        NodeID1: dbID,
×
2916
                },
×
2917
        )
×
2918
        if err != nil {
×
2919
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2920
        }
×
2921

2922
        // Exit early if there are no channels for this node so we don't
2923
        // do the unnecessary feature fetching.
2924
        if len(rows) == 0 {
×
2925
                return nil
×
2926
        }
×
2927

2928
        features, err := getNodeFeatures(ctx, db, dbID)
×
2929
        if err != nil {
×
2930
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2931
        }
×
2932

UNCOV
2933
        for _, row := range rows {
×
UNCOV
2934
                node1, node2, err := buildNodeVertices(
×
UNCOV
2935
                        row.Node1Pubkey, row.Node2Pubkey,
×
2936
                )
×
2937
                if err != nil {
×
2938
                        return fmt.Errorf("unable to build node vertices: %w",
×
UNCOV
2939
                                err)
×
2940
                }
×
2941

2942
                edge := buildCacheableChannelInfo(
×
2943
                        row.GraphChannel, node1, node2,
×
UNCOV
2944
                )
×
2945

×
2946
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2947
                if err != nil {
×
2948
                        return err
×
2949
                }
×
2950

2951
                var p1, p2 *models.CachedEdgePolicy
×
2952
                if dbPol1 != nil {
×
UNCOV
2953
                        policy1, err := buildChanPolicy(
×
2954
                                *dbPol1, edge.ChannelID, nil, node2,
×
2955
                        )
×
2956
                        if err != nil {
×
2957
                                return err
×
2958
                        }
×
2959

2960
                        p1 = models.NewCachedPolicy(policy1)
×
2961
                }
UNCOV
2962
                if dbPol2 != nil {
×
2963
                        policy2, err := buildChanPolicy(
×
2964
                                *dbPol2, edge.ChannelID, nil, node1,
×
2965
                        )
×
2966
                        if err != nil {
×
2967
                                return err
×
2968
                        }
×
2969

2970
                        p2 = models.NewCachedPolicy(policy2)
×
2971
                }
2972

2973
                // Determine the outgoing and incoming policy for this
2974
                // channel and node combo.
2975
                outPolicy, inPolicy := p1, p2
×
2976
                if p1 != nil && node2 == nodePub {
×
2977
                        outPolicy, inPolicy = p2, p1
×
2978
                } else if p2 != nil && node1 != nodePub {
×
2979
                        outPolicy, inPolicy = p2, p1
×
2980
                }
×
2981

2982
                var cachedInPolicy *models.CachedEdgePolicy
×
UNCOV
2983
                if inPolicy != nil {
×
UNCOV
2984
                        cachedInPolicy = inPolicy
×
UNCOV
2985
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
UNCOV
2986
                        cachedInPolicy.ToNodeFeatures = features
×
2987
                }
×
2988

2989
                directedChannel := &DirectedChannel{
×
2990
                        ChannelID:    edge.ChannelID,
×
2991
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2992
                        OtherNode:    edge.NodeKey2Bytes,
×
UNCOV
2993
                        Capacity:     edge.Capacity,
×
2994
                        OutPolicySet: outPolicy != nil,
×
2995
                        InPolicy:     cachedInPolicy,
×
2996
                }
×
2997
                if outPolicy != nil {
×
2998
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2999
                                directedChannel.InboundFee = fee
×
UNCOV
3000
                        })
×
3001
                }
3002

3003
                if nodePub == edge.NodeKey2Bytes {
×
3004
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3005
                }
×
3006

3007
                if err := cb(directedChannel); err != nil {
×
3008
                        return err
×
3009
                }
×
3010
        }
3011

3012
        return nil
×
3013
}
3014

3015
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3016
// and executes the provided callback for each node.
3017
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
UNCOV
3018
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3019

×
3020
        lastID := int64(-1)
×
3021

×
UNCOV
3022
        for {
×
UNCOV
3023
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3024
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
UNCOV
3025
                                Version: int16(ProtocolV1),
×
UNCOV
3026
                                ID:      lastID,
×
UNCOV
3027
                                Limit:   pageSize,
×
UNCOV
3028
                        },
×
UNCOV
3029
                )
×
3030
                if err != nil {
×
3031
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3032
                }
×
3033

3034
                if len(nodes) == 0 {
×
3035
                        break
×
3036
                }
3037

3038
                for _, node := range nodes {
×
3039
                        var pub route.Vertex
×
3040
                        copy(pub[:], node.PubKey)
×
3041

×
3042
                        if err := cb(node.ID, pub); err != nil {
×
3043
                                return fmt.Errorf("forEachNodeCacheable "+
×
3044
                                        "callback failed for node(id=%d): %w",
×
UNCOV
3045
                                        node.ID, err)
×
3046
                        }
×
3047

UNCOV
3048
                        lastID = node.ID
×
3049
                }
3050
        }
3051

3052
        return nil
×
3053
}
3054

3055
// forEachNodeChannel iterates through all channels of a node, executing
3056
// the passed callback on each. The call-back is provided with the channel's
3057
// edge information, the outgoing policy and the incoming policy for the
3058
// channel and node combo.
3059
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3060
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3061
        *models.ChannelEdgePolicy,
UNCOV
3062
        *models.ChannelEdgePolicy) error) error {
×
UNCOV
3063

×
3064
        // Get all the V1 channels for this node.Add commentMore actions
×
UNCOV
3065
        rows, err := db.ListChannelsByNodeID(
×
UNCOV
3066
                ctx, sqlc.ListChannelsByNodeIDParams{
×
UNCOV
3067
                        Version: int16(ProtocolV1),
×
UNCOV
3068
                        NodeID1: id,
×
UNCOV
3069
                },
×
UNCOV
3070
        )
×
UNCOV
3071
        if err != nil {
×
UNCOV
3072
                return fmt.Errorf("unable to fetch channels: %w", err)
×
UNCOV
3073
        }
×
3074

3075
        // Call the call-back for each channel and its known policies.
3076
        for _, row := range rows {
×
3077
                node1, node2, err := buildNodeVertices(
×
3078
                        row.Node1Pubkey, row.Node2Pubkey,
×
3079
                )
×
3080
                if err != nil {
×
3081
                        return fmt.Errorf("unable to build node vertices: %w",
×
3082
                                err)
×
3083
                }
×
3084

3085
                edge, err := getAndBuildEdgeInfo(
×
UNCOV
3086
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
UNCOV
3087
                        node1, node2,
×
3088
                )
×
3089
                if err != nil {
×
3090
                        return fmt.Errorf("unable to build channel info: %w",
×
3091
                                err)
×
3092
                }
×
3093

3094
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3095
                if err != nil {
×
UNCOV
3096
                        return fmt.Errorf("unable to extract channel "+
×
3097
                                "policies: %w", err)
×
3098
                }
×
3099

3100
                p1, p2, err := getAndBuildChanPolicies(
×
3101
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3102
                )
×
3103
                if err != nil {
×
3104
                        return fmt.Errorf("unable to build channel "+
×
UNCOV
3105
                                "policies: %w", err)
×
3106
                }
×
3107

3108
                // Determine the outgoing and incoming policy for this
3109
                // channel and node combo.
3110
                p1ToNode := row.GraphChannel.NodeID2
×
UNCOV
3111
                p2ToNode := row.GraphChannel.NodeID1
×
3112
                outPolicy, inPolicy := p1, p2
×
3113
                if (p1 != nil && p1ToNode == id) ||
×
3114
                        (p2 != nil && p2ToNode != id) {
×
3115

×
3116
                        outPolicy, inPolicy = p2, p1
×
3117
                }
×
3118

UNCOV
3119
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
UNCOV
3120
                        return err
×
UNCOV
3121
                }
×
3122
        }
3123

3124
        return nil
×
3125
}
3126

3127
// updateChanEdgePolicy upserts the channel policy info we have stored for
3128
// a channel we already know of.
3129
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3130
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3131
        error) {
×
3132

×
3133
        var (
×
UNCOV
3134
                node1Pub, node2Pub route.Vertex
×
UNCOV
3135
                isNode1            bool
×
3136
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
UNCOV
3137
        )
×
UNCOV
3138

×
UNCOV
3139
        // Check that this edge policy refers to a channel that we already
×
UNCOV
3140
        // know of. We do this explicitly so that we can return the appropriate
×
UNCOV
3141
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
UNCOV
3142
        // abort the transaction which would abort the entire batch.
×
3143
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3144
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3145
                        Scid:    chanIDB,
×
3146
                        Version: int16(ProtocolV1),
×
3147
                },
×
3148
        )
×
3149
        if errors.Is(err, sql.ErrNoRows) {
×
3150
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3151
        } else if err != nil {
×
3152
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3153
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3154
        }
×
3155

3156
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3157
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3158

×
3159
        // Figure out which node this edge is from.
×
3160
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3161
        nodeID := dbChan.NodeID1
×
3162
        if !isNode1 {
×
3163
                nodeID = dbChan.NodeID2
×
3164
        }
×
3165

3166
        var (
×
UNCOV
3167
                inboundBase sql.NullInt64
×
3168
                inboundRate sql.NullInt64
×
3169
        )
×
3170
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3171
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3172
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3173
        })
×
3174

3175
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3176
                Version:     int16(ProtocolV1),
×
UNCOV
3177
                ChannelID:   dbChan.ID,
×
3178
                NodeID:      nodeID,
×
3179
                Timelock:    int32(edge.TimeLockDelta),
×
3180
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3181
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3182
                MinHtlcMsat: int64(edge.MinHTLC),
×
3183
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3184
                Disabled: sql.NullBool{
×
3185
                        Valid: true,
×
UNCOV
3186
                        Bool:  edge.IsDisabled(),
×
3187
                },
×
3188
                MaxHtlcMsat: sql.NullInt64{
×
3189
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3190
                        Int64: int64(edge.MaxHTLC),
×
3191
                },
×
3192
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3193
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3194
                InboundBaseFeeMsat:      inboundBase,
×
3195
                InboundFeeRateMilliMsat: inboundRate,
×
3196
                Signature:               edge.SigBytes,
×
3197
        })
×
3198
        if err != nil {
×
3199
                return node1Pub, node2Pub, isNode1,
×
3200
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3201
        }
×
3202

3203
        // Convert the flat extra opaque data into a map of TLV types to
3204
        // values.
3205
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3206
        if err != nil {
×
3207
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3208
                        "marshal extra opaque data: %w", err)
×
3209
        }
×
3210

3211
        // Update the channel policy's extra signed fields.
3212
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3213
        if err != nil {
×
UNCOV
3214
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
UNCOV
3215
                        "policy extra TLVs: %w", err)
×
UNCOV
3216
        }
×
3217

3218
        return node1Pub, node2Pub, isNode1, nil
×
3219
}
3220

3221
// getNodeByPubKey attempts to look up a target node by its public key.
3222
func getNodeByPubKey(ctx context.Context, db SQLQueries,
UNCOV
3223
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3224

×
3225
        dbNode, err := db.GetNodeByPubKey(
×
3226
                ctx, sqlc.GetNodeByPubKeyParams{
×
3227
                        Version: int16(ProtocolV1),
×
3228
                        PubKey:  pubKey[:],
×
UNCOV
3229
                },
×
3230
        )
×
UNCOV
3231
        if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
3232
                return 0, nil, ErrGraphNodeNotFound
×
UNCOV
3233
        } else if err != nil {
×
UNCOV
3234
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3235
        }
×
3236

3237
        node, err := buildNode(ctx, db, &dbNode)
×
3238
        if err != nil {
×
3239
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3240
        }
×
3241

3242
        return dbNode.ID, node, nil
×
3243
}
3244

3245
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3246
// provided database channel row and the public keys of the two nodes
3247
// involved in the channel.
3248
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
3249
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3250

×
3251
        return &models.CachedEdgeInfo{
×
3252
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
UNCOV
3253
                NodeKey1Bytes: node1Pub,
×
3254
                NodeKey2Bytes: node2Pub,
×
UNCOV
3255
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
UNCOV
3256
        }
×
UNCOV
3257
}
×
3258

3259
// buildNode constructs a LightningNode instance from the given database node
3260
// record. The node's features, addresses and extra signed fields are also
3261
// fetched from the database and set on the node.
3262
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3263
        *models.LightningNode, error) {
×
3264

×
3265
        if dbNode.Version != int16(ProtocolV1) {
×
3266
                return nil, fmt.Errorf("unsupported node version: %d",
×
3267
                        dbNode.Version)
×
3268
        }
×
3269

UNCOV
3270
        var pub [33]byte
×
UNCOV
3271
        copy(pub[:], dbNode.PubKey)
×
UNCOV
3272

×
UNCOV
3273
        node := &models.LightningNode{
×
UNCOV
3274
                PubKeyBytes: pub,
×
3275
                Features:    lnwire.EmptyFeatureVector(),
×
3276
                LastUpdate:  time.Unix(0, 0),
×
3277
        }
×
3278

×
3279
        if len(dbNode.Signature) == 0 {
×
3280
                return node, nil
×
UNCOV
3281
        }
×
3282

3283
        node.HaveNodeAnnouncement = true
×
3284
        node.AuthSigBytes = dbNode.Signature
×
3285
        node.Alias = dbNode.Alias.String
×
3286
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3287

×
3288
        var err error
×
3289
        if dbNode.Color.Valid {
×
3290
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3291
                if err != nil {
×
3292
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3293
                                err)
×
UNCOV
3294
                }
×
3295
        }
3296

3297
        // Fetch the node's features.
3298
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3299
        if err != nil {
×
3300
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3301
                        "features: %w", dbNode.ID, err)
×
3302
        }
×
3303

3304
        // Fetch the node's addresses.
3305
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3306
        if err != nil {
×
UNCOV
3307
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
UNCOV
3308
                        "addresses: %w", dbNode.ID, err)
×
UNCOV
3309
        }
×
3310

3311
        // Fetch the node's extra signed fields.
3312
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3313
        if err != nil {
×
3314
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
UNCOV
3315
                        "extra signed fields: %w", dbNode.ID, err)
×
UNCOV
3316
        }
×
3317

3318
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3319
        if err != nil {
×
3320
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3321
                        "fields: %w", err)
×
UNCOV
3322
        }
×
3323

3324
        if len(recs) != 0 {
×
3325
                node.ExtraOpaqueData = recs
×
3326
        }
×
3327

3328
        return node, nil
×
3329
}
3330

3331
// getNodeFeatures fetches the feature bits and constructs the feature vector
3332
// for a node with the given DB ID.
3333
func getNodeFeatures(ctx context.Context, db SQLQueries,
3334
        nodeID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
3335

×
3336
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3337
        if err != nil {
×
3338
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
UNCOV
3339
                        nodeID, err)
×
3340
        }
×
3341

UNCOV
3342
        features := lnwire.EmptyFeatureVector()
×
UNCOV
3343
        for _, feature := range rows {
×
UNCOV
3344
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
UNCOV
3345
        }
×
3346

3347
        return features, nil
×
3348
}
3349

3350
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3351
// given DB ID.
3352
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
3353
        nodeID int64) (map[uint64][]byte, error) {
×
3354

×
3355
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3356
        if err != nil {
×
3357
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
UNCOV
3358
                        "signed fields: %w", nodeID, err)
×
3359
        }
×
3360

UNCOV
3361
        extraFields := make(map[uint64][]byte)
×
UNCOV
3362
        for _, field := range fields {
×
UNCOV
3363
                extraFields[uint64(field.Type)] = field.Value
×
UNCOV
3364
        }
×
3365

3366
        return extraFields, nil
×
3367
}
3368

3369
// upsertNode upserts the node record into the database. If the node already
3370
// exists, then the node's information is updated. If the node doesn't exist,
3371
// then a new node is created. The node's features, addresses and extra TLV
3372
// types are also updated. The node's DB ID is returned.
3373
func upsertNode(ctx context.Context, db SQLQueries,
3374
        node *models.LightningNode) (int64, error) {
×
3375

×
3376
        params := sqlc.UpsertNodeParams{
×
UNCOV
3377
                Version: int16(ProtocolV1),
×
3378
                PubKey:  node.PubKeyBytes[:],
×
UNCOV
3379
        }
×
UNCOV
3380

×
UNCOV
3381
        if node.HaveNodeAnnouncement {
×
UNCOV
3382
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
UNCOV
3383
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
UNCOV
3384
                params.Alias = sqldb.SQLStr(node.Alias)
×
UNCOV
3385
                params.Signature = node.AuthSigBytes
×
3386
        }
×
3387

3388
        nodeID, err := db.UpsertNode(ctx, params)
×
3389
        if err != nil {
×
3390
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3391
                        err)
×
3392
        }
×
3393

3394
        // We can exit here if we don't have the announcement yet.
3395
        if !node.HaveNodeAnnouncement {
×
3396
                return nodeID, nil
×
3397
        }
×
3398

3399
        // Update the node's features.
3400
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3401
        if err != nil {
×
3402
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3403
        }
×
3404

3405
        // Update the node's addresses.
UNCOV
3406
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3407
        if err != nil {
×
3408
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3409
        }
×
3410

3411
        // Convert the flat extra opaque data into a map of TLV types to
3412
        // values.
3413
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3414
        if err != nil {
×
3415
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
UNCOV
3416
                        err)
×
UNCOV
3417
        }
×
3418

3419
        // Update the node's extra signed fields.
3420
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3421
        if err != nil {
×
UNCOV
3422
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
UNCOV
3423
        }
×
3424

3425
        return nodeID, nil
×
3426
}
3427

3428
// upsertNodeFeatures updates the node's features node_features table. This
3429
// includes deleting any feature bits no longer present and inserting any new
3430
// feature bits. If the feature bit does not yet exist in the features table,
3431
// then an entry is created in that table first.
3432
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3433
        features *lnwire.FeatureVector) error {
×
3434

×
3435
        // Get any existing features for the node.
×
UNCOV
3436
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3437
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
UNCOV
3438
                return err
×
UNCOV
3439
        }
×
3440

3441
        // Copy the nodes latest set of feature bits.
UNCOV
3442
        newFeatures := make(map[int32]struct{})
×
UNCOV
3443
        if features != nil {
×
UNCOV
3444
                for feature := range features.Features() {
×
3445
                        newFeatures[int32(feature)] = struct{}{}
×
3446
                }
×
3447
        }
3448

3449
        // For any current feature that already exists in the DB, remove it from
3450
        // the in-memory map. For any existing feature that does not exist in
3451
        // the in-memory map, delete it from the database.
UNCOV
3452
        for _, feature := range existingFeatures {
×
UNCOV
3453
                // The feature is still present, so there are no updates to be
×
3454
                // made.
×
3455
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3456
                        delete(newFeatures, feature.FeatureBit)
×
3457
                        continue
×
3458
                }
3459

3460
                // The feature is no longer present, so we remove it from the
3461
                // database.
UNCOV
3462
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
UNCOV
3463
                        NodeID:     nodeID,
×
3464
                        FeatureBit: feature.FeatureBit,
×
3465
                })
×
3466
                if err != nil {
×
3467
                        return fmt.Errorf("unable to delete node(%d) "+
×
3468
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3469
                                err)
×
UNCOV
3470
                }
×
3471
        }
3472

3473
        // Any remaining entries in newFeatures are new features that need to be
3474
        // added to the database for the first time.
3475
        for feature := range newFeatures {
×
3476
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3477
                        NodeID:     nodeID,
×
3478
                        FeatureBit: feature,
×
3479
                })
×
3480
                if err != nil {
×
3481
                        return fmt.Errorf("unable to insert node(%d) "+
×
3482
                                "feature(%v): %w", nodeID, feature, err)
×
UNCOV
3483
                }
×
3484
        }
3485

UNCOV
3486
        return nil
×
3487
}
3488

3489
// fetchNodeFeatures fetches the features for a node with the given public key.
3490
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3491
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3492

×
3493
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3494
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3495
                        PubKey:  nodePub[:],
×
UNCOV
3496
                        Version: int16(ProtocolV1),
×
UNCOV
3497
                },
×
3498
        )
×
UNCOV
3499
        if err != nil {
×
UNCOV
3500
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
UNCOV
3501
                        nodePub, err)
×
UNCOV
3502
        }
×
3503

3504
        features := lnwire.EmptyFeatureVector()
×
3505
        for _, bit := range rows {
×
3506
                features.Set(lnwire.FeatureBit(bit))
×
3507
        }
×
3508

3509
        return features, nil
×
3510
}
3511

3512
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type determines how an address string is serialised when it is
// written and how it must be parsed when it is read back.
//
// NOTE: these values are written to the database, so the value of an
// existing entry must never change.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks an IPv4 TCP address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks an IPv6 TCP address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an address kept as opaque bytes; it is read
	// back by hex-decoding the stored string.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3524

3525
// upsertNodeAddresses updates the node's addresses in the database. This
3526
// includes deleting any existing addresses and inserting the new set of
3527
// addresses. The deletion is necessary since the ordering of the addresses may
3528
// change, and we need to ensure that the database reflects the latest set of
3529
// addresses so that at the time of reconstructing the node announcement, the
3530
// order is preserved and the signature over the message remains valid.
3531
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
UNCOV
3532
        addresses []net.Addr) error {
×
UNCOV
3533

×
UNCOV
3534
        // Delete any existing addresses for the node. This is required since
×
UNCOV
3535
        // even if the new set of addresses is the same, the ordering may have
×
UNCOV
3536
        // changed for a given address type.
×
UNCOV
3537
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
UNCOV
3538
        if err != nil {
×
UNCOV
3539
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
UNCOV
3540
                        nodeID, err)
×
UNCOV
3541
        }
×
3542

3543
        // Copy the nodes latest set of addresses.
3544
        newAddresses := map[dbAddressType][]string{
×
3545
                addressTypeIPv4:   {},
×
3546
                addressTypeIPv6:   {},
×
3547
                addressTypeTorV2:  {},
×
3548
                addressTypeTorV3:  {},
×
3549
                addressTypeOpaque: {},
×
3550
        }
×
3551
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3552
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3553
        }
×
3554

UNCOV
3555
        for _, address := range addresses {
×
3556
                switch addr := address.(type) {
×
3557
                case *net.TCPAddr:
×
3558
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3559
                                addAddr(addressTypeIPv4, addr)
×
3560
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3561
                                addAddr(addressTypeIPv6, addr)
×
3562
                        } else {
×
3563
                                return fmt.Errorf("unhandled IP address: %v",
×
3564
                                        addr)
×
3565
                        }
×
3566

3567
                case *tor.OnionAddr:
×
3568
                        switch len(addr.OnionService) {
×
3569
                        case tor.V2Len:
×
3570
                                addAddr(addressTypeTorV2, addr)
×
3571
                        case tor.V3Len:
×
3572
                                addAddr(addressTypeTorV3, addr)
×
3573
                        default:
×
3574
                                return fmt.Errorf("invalid length for a tor " +
×
3575
                                        "address")
×
3576
                        }
3577

UNCOV
3578
                case *lnwire.OpaqueAddrs:
×
3579
                        addAddr(addressTypeOpaque, addr)
×
3580

3581
                default:
×
3582
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3583
                }
3584
        }
3585

3586
        // Any remaining entries in newAddresses are new addresses that need to
3587
        // be added to the database for the first time.
UNCOV
3588
        for addrType, addrList := range newAddresses {
×
UNCOV
3589
                for position, addr := range addrList {
×
3590
                        err := db.InsertNodeAddress(
×
3591
                                ctx, sqlc.InsertNodeAddressParams{
×
UNCOV
3592
                                        NodeID:   nodeID,
×
3593
                                        Type:     int16(addrType),
×
3594
                                        Address:  addr,
×
UNCOV
3595
                                        Position: int32(position),
×
UNCOV
3596
                                },
×
UNCOV
3597
                        )
×
UNCOV
3598
                        if err != nil {
×
UNCOV
3599
                                return fmt.Errorf("unable to insert "+
×
3600
                                        "node(%d) address(%v): %w", nodeID,
×
3601
                                        addr, err)
×
3602
                        }
×
3603
                }
3604
        }
3605

3606
        return nil
×
3607
}
3608

3609
// getNodeAddresses fetches the addresses for a node with the given public key.
3610
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3611
        []net.Addr, error) {
×
3612

×
3613
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3614
        // are returned in the same order as they were inserted.
×
UNCOV
3615
        rows, err := db.GetNodeAddressesByPubKey(
×
UNCOV
3616
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
UNCOV
3617
                        Version: int16(ProtocolV1),
×
3618
                        PubKey:  nodePub,
×
UNCOV
3619
                },
×
UNCOV
3620
        )
×
UNCOV
3621
        if err != nil {
×
UNCOV
3622
                return false, nil, err
×
3623
        }
×
3624

3625
        // GetNodeAddressesByPubKey uses a left join so there should always be
3626
        // at least one row returned if the node exists even if it has no
3627
        // addresses.
3628
        if len(rows) == 0 {
×
3629
                return false, nil, nil
×
3630
        }
×
3631

3632
        addresses := make([]net.Addr, 0, len(rows))
×
3633
        for _, addr := range rows {
×
3634
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3635
                        continue
×
3636
                }
3637

UNCOV
3638
                address := addr.Address.String
×
UNCOV
3639

×
3640
                switch dbAddressType(addr.Type.Int16) {
×
3641
                case addressTypeIPv4:
×
3642
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
UNCOV
3643
                        if err != nil {
×
3644
                                return false, nil, nil
×
3645
                        }
×
3646
                        tcp.IP = tcp.IP.To4()
×
3647

×
UNCOV
3648
                        addresses = append(addresses, tcp)
×
3649

3650
                case addressTypeIPv6:
×
3651
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3652
                        if err != nil {
×
3653
                                return false, nil, nil
×
3654
                        }
×
3655
                        addresses = append(addresses, tcp)
×
3656

3657
                case addressTypeTorV3, addressTypeTorV2:
×
3658
                        service, portStr, err := net.SplitHostPort(address)
×
3659
                        if err != nil {
×
3660
                                return false, nil, fmt.Errorf("unable to "+
×
UNCOV
3661
                                        "split tor v3 address: %v",
×
3662
                                        addr.Address)
×
3663
                        }
×
3664

3665
                        port, err := strconv.Atoi(portStr)
×
3666
                        if err != nil {
×
3667
                                return false, nil, err
×
UNCOV
3668
                        }
×
3669

3670
                        addresses = append(addresses, &tor.OnionAddr{
×
3671
                                OnionService: service,
×
3672
                                Port:         port,
×
3673
                        })
×
3674

3675
                case addressTypeOpaque:
×
UNCOV
3676
                        opaque, err := hex.DecodeString(address)
×
3677
                        if err != nil {
×
3678
                                return false, nil, fmt.Errorf("unable to "+
×
3679
                                        "decode opaque address: %v", addr)
×
3680
                        }
×
3681

3682
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3683
                                Payload: opaque,
×
3684
                        })
×
3685

UNCOV
3686
                default:
×
3687
                        return false, nil, fmt.Errorf("unknown address "+
×
3688
                                "type: %v", addr.Type)
×
3689
                }
3690
        }
3691

3692
        // If we have no addresses, then we'll return nil instead of an
3693
        // empty slice.
3694
        if len(addresses) == 0 {
×
3695
                addresses = nil
×
3696
        }
×
3697

3698
        return true, addresses, nil
×
3699
}
3700

3701
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3702
// database. This includes updating any existing types, inserting any new types,
3703
// and deleting any types that are no longer present.
3704
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
3705
        nodeID int64, extraFields map[uint64][]byte) error {
×
3706

×
3707
        // Get any existing extra signed fields for the node.
×
3708
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
UNCOV
3709
        if err != nil {
×
3710
                return err
×
UNCOV
3711
        }
×
3712

3713
        // Make a lookup map of the existing field types so that we can use it
3714
        // to keep track of any fields we should delete.
UNCOV
3715
        m := make(map[uint64]bool)
×
UNCOV
3716
        for _, field := range existingFields {
×
3717
                m[uint64(field.Type)] = true
×
3718
        }
×
3719

3720
        // For all the new fields, we'll upsert them and remove them from the
3721
        // map of existing fields.
3722
        for tlvType, value := range extraFields {
×
3723
                err = db.UpsertNodeExtraType(
×
UNCOV
3724
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
UNCOV
3725
                                NodeID: nodeID,
×
UNCOV
3726
                                Type:   int64(tlvType),
×
3727
                                Value:  value,
×
3728
                        },
×
3729
                )
×
3730
                if err != nil {
×
UNCOV
3731
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
UNCOV
3732
                                "signed field(%v): %w", nodeID, tlvType, err)
×
UNCOV
3733
                }
×
3734

3735
                // Remove the field from the map of existing fields if it was
3736
                // present.
3737
                delete(m, tlvType)
×
3738
        }
3739

3740
        // For all the fields that are left in the map of existing fields, we'll
3741
        // delete them as they are no longer present in the new set of fields.
3742
        for tlvType := range m {
×
3743
                err = db.DeleteExtraNodeType(
×
3744
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3745
                                NodeID: nodeID,
×
UNCOV
3746
                                Type:   int64(tlvType),
×
UNCOV
3747
                        },
×
UNCOV
3748
                )
×
3749
                if err != nil {
×
UNCOV
3750
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
UNCOV
3751
                                "signed field(%v): %w", nodeID, tlvType, err)
×
UNCOV
3752
                }
×
3753
        }
3754

3755
        return nil
×
3756
}
3757

3758
// srcNodeInfo holds the information about the source node of the graph.
3759
type srcNodeInfo struct {
3760
        // id is the DB level ID of the source node entry in the "nodes" table.
3761
        id int64
3762

3763
        // pub is the public key of the source node.
3764
        pub route.Vertex
3765
}
3766

3767
// sourceNode returns the DB node ID and pub key of the source node for the
3768
// specified protocol version.
3769
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
UNCOV
3770
        version ProtocolVersion) (int64, route.Vertex, error) {
×
UNCOV
3771

×
UNCOV
3772
        s.srcNodeMu.Lock()
×
UNCOV
3773
        defer s.srcNodeMu.Unlock()
×
UNCOV
3774

×
UNCOV
3775
        // If we already have the source node ID and pub key cached, then
×
UNCOV
3776
        // return them.
×
UNCOV
3777
        if info, ok := s.srcNodes[version]; ok {
×
UNCOV
3778
                return info.id, info.pub, nil
×
UNCOV
3779
        }
×
3780

UNCOV
3781
        var pubKey route.Vertex
×
3782

×
3783
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3784
        if err != nil {
×
3785
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3786
                        err)
×
3787
        }
×
3788

3789
        if len(nodes) == 0 {
×
3790
                return 0, pubKey, ErrSourceNodeNotSet
×
3791
        } else if len(nodes) > 1 {
×
UNCOV
3792
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3793
                        "protocol %s found", version)
×
3794
        }
×
3795

3796
        copy(pubKey[:], nodes[0].PubKey)
×
3797

×
3798
        s.srcNodes[version] = &srcNodeInfo{
×
3799
                id:  nodes[0].NodeID,
×
UNCOV
3800
                pub: pubKey,
×
3801
        }
×
3802

×
3803
        return nodes[0].NodeID, pubKey, nil
×
3804
}
3805

3806
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3807
// This then produces a map from TLV type to value. If the input is not a
3808
// valid TLV stream, then an error is returned.
3809
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3810
        r := bytes.NewReader(data)
×
3811

×
3812
        tlvStream, err := tlv.NewStream()
×
3813
        if err != nil {
×
3814
                return nil, err
×
3815
        }
×
3816

3817
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3818
        // pass it into the P2P decoding variant.
UNCOV
3819
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
UNCOV
3820
        if err != nil {
×
3821
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3822
        }
×
3823
        if len(parsedTypes) == 0 {
×
3824
                return nil, nil
×
3825
        }
×
3826

3827
        records := make(map[uint64][]byte)
×
UNCOV
3828
        for k, v := range parsedTypes {
×
UNCOV
3829
                records[uint64(k)] = v
×
UNCOV
3830
        }
×
3831

3832
        return records, nil
×
3833
}
3834

3835
// dbChanInfo bundles the database row IDs produced when a channel is stored:
// the channel's own ID and the IDs of its two endpoint nodes.
type dbChanInfo struct {
	// channelID is the channel's row ID.
	channelID int64

	// node1ID is the row ID of the channel's first node.
	node1ID int64

	// node2ID is the row ID of the channel's second node.
	node2ID int64
}
3842

3843
// insertChannel inserts a new channel record into the database.
3844
func insertChannel(ctx context.Context, db SQLQueries,
UNCOV
3845
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
UNCOV
3846

×
UNCOV
3847
        chanIDB := channelIDToBytes(edge.ChannelID)
×
UNCOV
3848

×
UNCOV
3849
        // Make sure that the channel doesn't already exist. We do this
×
UNCOV
3850
        // explicitly instead of relying on catching a unique constraint error
×
UNCOV
3851
        // because relying on SQL to throw that error would abort the entire
×
UNCOV
3852
        // batch of transactions.
×
UNCOV
3853
        _, err := db.GetChannelBySCID(
×
UNCOV
3854
                ctx, sqlc.GetChannelBySCIDParams{
×
UNCOV
3855
                        Scid:    chanIDB,
×
UNCOV
3856
                        Version: int16(ProtocolV1),
×
3857
                },
×
3858
        )
×
3859
        if err == nil {
×
3860
                return nil, ErrEdgeAlreadyExist
×
3861
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3862
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3863
        }
×
3864

3865
        // Make sure that at least a "shell" entry for each node is present in
3866
        // the nodes table.
3867
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3868
        if err != nil {
×
3869
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3870
        }
×
3871

3872
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3873
        if err != nil {
×
3874
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3875
        }
×
3876

UNCOV
3877
        var capacity sql.NullInt64
×
UNCOV
3878
        if edge.Capacity != 0 {
×
3879
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3880
        }
×
3881

3882
        createParams := sqlc.CreateChannelParams{
×
UNCOV
3883
                Version:     int16(ProtocolV1),
×
3884
                Scid:        chanIDB,
×
3885
                NodeID1:     node1DBID,
×
3886
                NodeID2:     node2DBID,
×
3887
                Outpoint:    edge.ChannelPoint.String(),
×
UNCOV
3888
                Capacity:    capacity,
×
3889
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3890
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3891
        }
×
3892

×
UNCOV
3893
        if edge.AuthProof != nil {
×
3894
                proof := edge.AuthProof
×
3895

×
3896
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3897
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3898
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3899
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3900
        }
×
3901

3902
        // Insert the new channel record.
3903
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3904
        if err != nil {
×
3905
                return nil, err
×
3906
        }
×
3907

3908
        // Insert any channel features.
3909
        for feature := range edge.Features.Features() {
×
3910
                err = db.InsertChannelFeature(
×
3911
                        ctx, sqlc.InsertChannelFeatureParams{
×
3912
                                ChannelID:  dbChanID,
×
UNCOV
3913
                                FeatureBit: int32(feature),
×
UNCOV
3914
                        },
×
3915
                )
×
3916
                if err != nil {
×
3917
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3918
                                "feature(%v): %w", dbChanID, feature, err)
×
UNCOV
3919
                }
×
3920
        }
3921

3922
        // Finally, insert any extra TLV fields in the channel announcement.
3923
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3924
        if err != nil {
×
3925
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3926
                        "data: %w", err)
×
3927
        }
×
3928

3929
        for tlvType, value := range extra {
×
3930
                err := db.CreateChannelExtraType(
×
3931
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
UNCOV
3932
                                ChannelID: dbChanID,
×
UNCOV
3933
                                Type:      int64(tlvType),
×
UNCOV
3934
                                Value:     value,
×
3935
                        },
×
3936
                )
×
3937
                if err != nil {
×
3938
                        return nil, fmt.Errorf("unable to upsert "+
×
3939
                                "channel(%d) extra signed field(%v): %w",
×
UNCOV
3940
                                edge.ChannelID, tlvType, err)
×
3941
                }
×
3942
        }
3943

3944
        return &dbChanInfo{
×
3945
                channelID: dbChanID,
×
3946
                node1ID:   node1DBID,
×
3947
                node2ID:   node2DBID,
×
3948
        }, nil
×
3949
}
3950

3951
// maybeCreateShellNode checks if a shell node entry exists for the
3952
// given public key. If it does not exist, then a new shell node entry is
3953
// created. The ID of the node is returned. A shell node only has a protocol
3954
// version and public key persisted.
3955
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3956
        pubKey route.Vertex) (int64, error) {
×
3957

×
3958
        dbNode, err := db.GetNodeByPubKey(
×
3959
                ctx, sqlc.GetNodeByPubKeyParams{
×
3960
                        PubKey:  pubKey[:],
×
UNCOV
3961
                        Version: int16(ProtocolV1),
×
UNCOV
3962
                },
×
UNCOV
3963
        )
×
UNCOV
3964
        // The node exists. Return the ID.
×
UNCOV
3965
        if err == nil {
×
UNCOV
3966
                return dbNode.ID, nil
×
UNCOV
3967
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3968
                return 0, err
×
3969
        }
×
3970

3971
        // Otherwise, the node does not exist, so we create a shell entry for
3972
        // it.
3973
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3974
                Version: int16(ProtocolV1),
×
3975
                PubKey:  pubKey[:],
×
3976
        })
×
3977
        if err != nil {
×
3978
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3979
        }
×
3980

3981
        return id, nil
×
3982
}
3983

3984
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3985
// the database. This includes deleting any existing types and then inserting
3986
// the new types.
3987
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3988
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3989

×
3990
        // Delete all existing extra signed fields for the channel policy.
×
3991
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
UNCOV
3992
        if err != nil {
×
3993
                return fmt.Errorf("unable to delete "+
×
UNCOV
3994
                        "existing policy extra signed fields for policy %d: %w",
×
UNCOV
3995
                        chanPolicyID, err)
×
UNCOV
3996
        }
×
3997

3998
        // Insert all new extra signed fields for the channel policy.
UNCOV
3999
        for tlvType, value := range extraFields {
×
4000
                err = db.InsertChanPolicyExtraType(
×
4001
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4002
                                ChannelPolicyID: chanPolicyID,
×
4003
                                Type:            int64(tlvType),
×
4004
                                Value:           value,
×
4005
                        },
×
4006
                )
×
4007
                if err != nil {
×
4008
                        return fmt.Errorf("unable to insert "+
×
UNCOV
4009
                                "channel_policy(%d) extra signed field(%v): %w",
×
UNCOV
4010
                                chanPolicyID, tlvType, err)
×
4011
                }
×
4012
        }
4013

4014
        return nil
×
4015
}
4016

4017
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4018
// provided dbChanRow and also fetches any other required information
4019
// to construct the edge info.
4020
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4021
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
4022
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4023

×
UNCOV
4024
        if dbChan.Version != int16(ProtocolV1) {
×
UNCOV
4025
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4026
                        dbChan.Version)
×
UNCOV
4027
        }
×
4028

UNCOV
4029
        fv, extras, err := getChanFeaturesAndExtras(
×
UNCOV
4030
                ctx, db, dbChanID,
×
UNCOV
4031
        )
×
UNCOV
4032
        if err != nil {
×
UNCOV
4033
                return nil, err
×
4034
        }
×
4035

4036
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4037
        if err != nil {
×
4038
                return nil, err
×
4039
        }
×
4040

4041
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4042
        if err != nil {
×
4043
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4044
                        "fields: %w", err)
×
4045
        }
×
4046
        if recs == nil {
×
UNCOV
4047
                recs = make([]byte, 0)
×
4048
        }
×
4049

4050
        var btcKey1, btcKey2 route.Vertex
×
4051
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
UNCOV
4052
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4053

×
4054
        channel := &models.ChannelEdgeInfo{
×
4055
                ChainHash:        chain,
×
4056
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4057
                NodeKey1Bytes:    node1,
×
4058
                NodeKey2Bytes:    node2,
×
4059
                BitcoinKey1Bytes: btcKey1,
×
4060
                BitcoinKey2Bytes: btcKey2,
×
UNCOV
4061
                ChannelPoint:     *op,
×
4062
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4063
                Features:         fv,
×
4064
                ExtraOpaqueData:  recs,
×
4065
        }
×
4066

×
4067
        // We always set all the signatures at the same time, so we can
×
4068
        // safely check if one signature is present to determine if we have the
×
4069
        // rest of the signatures for the auth proof.
×
4070
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4071
                channel.AuthProof = &models.ChannelAuthProof{
×
4072
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4073
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4074
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4075
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4076
                }
×
4077
        }
×
4078

4079
        return channel, nil
×
4080
}
4081

4082
// buildNodeVertices is a helper that converts raw node public keys
4083
// into route.Vertex instances.
4084
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4085
        route.Vertex, error) {
×
4086

×
4087
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4088
        if err != nil {
×
4089
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
UNCOV
4090
                        "create vertex from node1 pubkey: %w", err)
×
4091
        }
×
4092

UNCOV
4093
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
UNCOV
4094
        if err != nil {
×
UNCOV
4095
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
UNCOV
4096
                        "create vertex from node2 pubkey: %w", err)
×
4097
        }
×
4098

4099
        return node1Vertex, node2Vertex, nil
×
4100
}
4101

4102
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4103
// for a channel with the given ID.
4104
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4105
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4106

×
4107
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4108
        if err != nil {
×
4109
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
UNCOV
4110
                        "features and extras: %w", err)
×
4111
        }
×
4112

UNCOV
4113
        var (
×
UNCOV
4114
                fv     = lnwire.EmptyFeatureVector()
×
UNCOV
4115
                extras = make(map[uint64][]byte)
×
UNCOV
4116
        )
×
4117
        for _, row := range rows {
×
4118
                if row.IsFeature {
×
4119
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4120

×
4121
                        continue
×
4122
                }
4123

UNCOV
4124
                tlvType, ok := row.ExtraKey.(int64)
×
4125
                if !ok {
×
4126
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4127
                                "TLV type: %T", row.ExtraKey)
×
4128
                }
×
4129

4130
                valueBytes, ok := row.Value.([]byte)
×
4131
                if !ok {
×
4132
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4133
                                "Value: %T", row.Value)
×
UNCOV
4134
                }
×
4135

4136
                extras[uint64(tlvType)] = valueBytes
×
4137
        }
4138

4139
        return fv, extras, nil
×
4140
}
4141

4142
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4143
// retrieves all the extra info required to build the complete
4144
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4145
// the provided sqlc.GraphChannelPolicy records are nil.
4146
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4147
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4148
        node2 route.Vertex) (*models.ChannelEdgePolicy,
UNCOV
4149
        *models.ChannelEdgePolicy, error) {
×
UNCOV
4150

×
4151
        if dbPol1 == nil && dbPol2 == nil {
×
UNCOV
4152
                return nil, nil, nil
×
UNCOV
4153
        }
×
4154

UNCOV
4155
        var (
×
UNCOV
4156
                policy1ID int64
×
UNCOV
4157
                policy2ID int64
×
UNCOV
4158
        )
×
UNCOV
4159
        if dbPol1 != nil {
×
UNCOV
4160
                policy1ID = dbPol1.ID
×
4161
        }
×
4162
        if dbPol2 != nil {
×
4163
                policy2ID = dbPol2.ID
×
4164
        }
×
4165
        rows, err := db.GetChannelPolicyExtraTypes(
×
UNCOV
4166
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4167
                        ID:   policy1ID,
×
4168
                        ID_2: policy2ID,
×
4169
                },
×
4170
        )
×
4171
        if err != nil {
×
4172
                return nil, nil, err
×
4173
        }
×
4174

4175
        var (
×
4176
                dbPol1Extras = make(map[uint64][]byte)
×
4177
                dbPol2Extras = make(map[uint64][]byte)
×
4178
        )
×
4179
        for _, row := range rows {
×
4180
                switch row.PolicyID {
×
4181
                case policy1ID:
×
4182
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4183
                case policy2ID:
×
4184
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4185
                default:
×
UNCOV
4186
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4187
                                "in row: %v", row.PolicyID, row)
×
4188
                }
4189
        }
4190

4191
        var pol1, pol2 *models.ChannelEdgePolicy
×
4192
        if dbPol1 != nil {
×
4193
                pol1, err = buildChanPolicy(
×
4194
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4195
                )
×
4196
                if err != nil {
×
4197
                        return nil, nil, err
×
4198
                }
×
4199
        }
UNCOV
4200
        if dbPol2 != nil {
×
UNCOV
4201
                pol2, err = buildChanPolicy(
×
UNCOV
4202
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4203
                )
×
4204
                if err != nil {
×
4205
                        return nil, nil, err
×
4206
                }
×
4207
        }
4208

4209
        return pol1, pol2, nil
×
4210
}
4211

4212
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4213
// provided sqlc.GraphChannelPolicy and other required information.
4214
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4215
        extras map[uint64][]byte,
4216
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4217

×
4218
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
UNCOV
4219
        if err != nil {
×
UNCOV
4220
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4221
                        "fields: %w", err)
×
UNCOV
4222
        }
×
4223

UNCOV
4224
        var inboundFee fn.Option[lnwire.Fee]
×
UNCOV
4225
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
UNCOV
4226
                dbPolicy.InboundBaseFeeMsat.Valid {
×
UNCOV
4227

×
4228
                inboundFee = fn.Some(lnwire.Fee{
×
4229
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4230
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4231
                })
×
4232
        }
×
4233

4234
        return &models.ChannelEdgePolicy{
×
UNCOV
4235
                SigBytes:  dbPolicy.Signature,
×
4236
                ChannelID: channelID,
×
4237
                LastUpdate: time.Unix(
×
4238
                        dbPolicy.LastUpdate.Int64, 0,
×
4239
                ),
×
4240
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4241
                        dbPolicy.MessageFlags,
×
4242
                ),
×
4243
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4244
                        dbPolicy.ChannelFlags,
×
UNCOV
4245
                ),
×
4246
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4247
                MinHTLC: lnwire.MilliSatoshi(
×
4248
                        dbPolicy.MinHtlcMsat,
×
4249
                ),
×
4250
                MaxHTLC: lnwire.MilliSatoshi(
×
4251
                        dbPolicy.MaxHtlcMsat.Int64,
×
4252
                ),
×
4253
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4254
                        dbPolicy.BaseFeeMsat,
×
4255
                ),
×
4256
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4257
                ToNode:                    toNode,
×
4258
                InboundFee:                inboundFee,
×
4259
                ExtraOpaqueData:           recs,
×
4260
        }, nil
×
4261
}
4262

4263
// buildNodes builds the models.LightningNode instances for the
4264
// given row which is expected to be a sqlc type that contains node information.
4265
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4266
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4267
        error) {
×
4268

×
4269
        node1, err := buildNode(ctx, db, &dbNode1)
×
4270
        if err != nil {
×
4271
                return nil, nil, err
×
4272
        }
×
4273

UNCOV
4274
        node2, err := buildNode(ctx, db, &dbNode2)
×
UNCOV
4275
        if err != nil {
×
UNCOV
4276
                return nil, nil, err
×
UNCOV
4277
        }
×
4278

4279
        return node1, node2, nil
×
4280
}
4281

4282
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4283
// row which is expected to be a sqlc type that contains channel policy
4284
// information. It returns two policies, which may be nil if the policy
4285
// information is not present in the row.
4286
//
4287
//nolint:ll,dupl,funlen
4288
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4289
        *sqlc.GraphChannelPolicy, error) {
×
UNCOV
4290

×
4291
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
UNCOV
4292
        switch r := row.(type) {
×
UNCOV
4293
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
UNCOV
4294
                if r.Policy1ID.Valid {
×
UNCOV
4295
                        policy1 = &sqlc.GraphChannelPolicy{
×
UNCOV
4296
                                ID:                      r.Policy1ID.Int64,
×
UNCOV
4297
                                Version:                 r.Policy1Version.Int16,
×
UNCOV
4298
                                ChannelID:               r.GraphChannel.ID,
×
UNCOV
4299
                                NodeID:                  r.Policy1NodeID.Int64,
×
UNCOV
4300
                                Timelock:                r.Policy1Timelock.Int32,
×
4301
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4302
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4303
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4304
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4305
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4306
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4307
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4308
                                Disabled:                r.Policy1Disabled,
×
NEW
4309
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4310
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
4311
                                Signature:               r.Policy1Signature,
×
NEW
4312
                        }
×
NEW
4313
                }
×
NEW
4314
                if r.Policy2ID.Valid {
×
NEW
4315
                        policy2 = &sqlc.GraphChannelPolicy{
×
NEW
4316
                                ID:                      r.Policy2ID.Int64,
×
NEW
4317
                                Version:                 r.Policy2Version.Int16,
×
NEW
4318
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4319
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4320
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4321
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4322
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4323
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4324
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4325
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4326
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4327
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4328
                                Disabled:                r.Policy2Disabled,
×
NEW
4329
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4330
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
4331
                                Signature:               r.Policy2Signature,
×
NEW
4332
                        }
×
NEW
4333
                }
×
4334

NEW
4335
                return policy1, policy2, nil
×
4336

NEW
4337
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
NEW
4338
                if r.Policy1ID.Valid {
×
NEW
4339
                        policy1 = &sqlc.GraphChannelPolicy{
×
NEW
4340
                                ID:                      r.Policy1ID.Int64,
×
NEW
4341
                                Version:                 r.Policy1Version.Int16,
×
NEW
4342
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4343
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4344
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4345
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4346
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4347
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4348
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4349
                                LastUpdate:              r.Policy1LastUpdate,
×
4350
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4351
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4352
                                Disabled:                r.Policy1Disabled,
×
4353
                                MessageFlags:            r.Policy1MessageFlags,
×
4354
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4355
                                Signature:               r.Policy1Signature,
×
4356
                        }
×
4357
                }
×
4358
                if r.Policy2ID.Valid {
×
4359
                        policy2 = &sqlc.GraphChannelPolicy{
×
4360
                                ID:                      r.Policy2ID.Int64,
×
4361
                                Version:                 r.Policy2Version.Int16,
×
4362
                                ChannelID:               r.GraphChannel.ID,
×
4363
                                NodeID:                  r.Policy2NodeID.Int64,
×
4364
                                Timelock:                r.Policy2Timelock.Int32,
×
4365
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4366
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4367
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4368
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4369
                                LastUpdate:              r.Policy2LastUpdate,
×
4370
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4371
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4372
                                Disabled:                r.Policy2Disabled,
×
4373
                                MessageFlags:            r.Policy2MessageFlags,
×
4374
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4375
                                Signature:               r.Policy2Signature,
×
4376
                        }
×
4377
                }
×
4378

4379
                return policy1, policy2, nil
×
4380

4381
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4382
                if r.Policy1ID.Valid {
×
4383
                        policy1 = &sqlc.GraphChannelPolicy{
×
4384
                                ID:                      r.Policy1ID.Int64,
×
4385
                                Version:                 r.Policy1Version.Int16,
×
4386
                                ChannelID:               r.GraphChannel.ID,
×
4387
                                NodeID:                  r.Policy1NodeID.Int64,
×
4388
                                Timelock:                r.Policy1Timelock.Int32,
×
4389
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
UNCOV
4390
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4391
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
UNCOV
4392
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4393
                                LastUpdate:              r.Policy1LastUpdate,
×
4394
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4395
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4396
                                Disabled:                r.Policy1Disabled,
×
4397
                                MessageFlags:            r.Policy1MessageFlags,
×
4398
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4399
                                Signature:               r.Policy1Signature,
×
4400
                        }
×
4401
                }
×
4402
                if r.Policy2ID.Valid {
×
4403
                        policy2 = &sqlc.GraphChannelPolicy{
×
4404
                                ID:                      r.Policy2ID.Int64,
×
4405
                                Version:                 r.Policy2Version.Int16,
×
4406
                                ChannelID:               r.GraphChannel.ID,
×
4407
                                NodeID:                  r.Policy2NodeID.Int64,
×
4408
                                Timelock:                r.Policy2Timelock.Int32,
×
4409
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4410
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4411
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4412
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4413
                                LastUpdate:              r.Policy2LastUpdate,
×
4414
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4415
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4416
                                Disabled:                r.Policy2Disabled,
×
4417
                                MessageFlags:            r.Policy2MessageFlags,
×
4418
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4419
                                Signature:               r.Policy2Signature,
×
4420
                        }
×
4421
                }
×
4422

4423
                return policy1, policy2, nil
×
4424

4425
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4426
                if r.Policy1ID.Valid {
×
4427
                        policy1 = &sqlc.GraphChannelPolicy{
×
4428
                                ID:                      r.Policy1ID.Int64,
×
4429
                                Version:                 r.Policy1Version.Int16,
×
4430
                                ChannelID:               r.GraphChannel.ID,
×
4431
                                NodeID:                  r.Policy1NodeID.Int64,
×
4432
                                Timelock:                r.Policy1Timelock.Int32,
×
4433
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
UNCOV
4434
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4435
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
UNCOV
4436
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4437
                                LastUpdate:              r.Policy1LastUpdate,
×
4438
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4439
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4440
                                Disabled:                r.Policy1Disabled,
×
4441
                                MessageFlags:            r.Policy1MessageFlags,
×
4442
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4443
                                Signature:               r.Policy1Signature,
×
4444
                        }
×
4445
                }
×
4446
                if r.Policy2ID.Valid {
×
4447
                        policy2 = &sqlc.GraphChannelPolicy{
×
4448
                                ID:                      r.Policy2ID.Int64,
×
4449
                                Version:                 r.Policy2Version.Int16,
×
4450
                                ChannelID:               r.GraphChannel.ID,
×
4451
                                NodeID:                  r.Policy2NodeID.Int64,
×
4452
                                Timelock:                r.Policy2Timelock.Int32,
×
4453
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4454
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4455
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4456
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4457
                                LastUpdate:              r.Policy2LastUpdate,
×
4458
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4459
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4460
                                Disabled:                r.Policy2Disabled,
×
4461
                                MessageFlags:            r.Policy2MessageFlags,
×
4462
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4463
                                Signature:               r.Policy2Signature,
×
4464
                        }
×
4465
                }
×
4466

4467
                return policy1, policy2, nil
×
4468

4469
        case sqlc.ListChannelsByNodeIDRow:
×
4470
                if r.Policy1ID.Valid {
×
4471
                        policy1 = &sqlc.GraphChannelPolicy{
×
4472
                                ID:                      r.Policy1ID.Int64,
×
4473
                                Version:                 r.Policy1Version.Int16,
×
4474
                                ChannelID:               r.GraphChannel.ID,
×
4475
                                NodeID:                  r.Policy1NodeID.Int64,
×
4476
                                Timelock:                r.Policy1Timelock.Int32,
×
4477
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
UNCOV
4478
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4479
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
UNCOV
4480
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4481
                                LastUpdate:              r.Policy1LastUpdate,
×
4482
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4483
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4484
                                Disabled:                r.Policy1Disabled,
×
4485
                                MessageFlags:            r.Policy1MessageFlags,
×
4486
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4487
                                Signature:               r.Policy1Signature,
×
4488
                        }
×
4489
                }
×
4490
                if r.Policy2ID.Valid {
×
4491
                        policy2 = &sqlc.GraphChannelPolicy{
×
4492
                                ID:                      r.Policy2ID.Int64,
×
4493
                                Version:                 r.Policy2Version.Int16,
×
4494
                                ChannelID:               r.GraphChannel.ID,
×
4495
                                NodeID:                  r.Policy2NodeID.Int64,
×
4496
                                Timelock:                r.Policy2Timelock.Int32,
×
4497
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4498
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4499
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4500
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4501
                                LastUpdate:              r.Policy2LastUpdate,
×
4502
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4503
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4504
                                Disabled:                r.Policy2Disabled,
×
4505
                                MessageFlags:            r.Policy2MessageFlags,
×
4506
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4507
                                Signature:               r.Policy2Signature,
×
4508
                        }
×
4509
                }
×
4510

4511
                return policy1, policy2, nil
×
4512

4513
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4514
                if r.Policy1ID.Valid {
×
4515
                        policy1 = &sqlc.GraphChannelPolicy{
×
4516
                                ID:                      r.Policy1ID.Int64,
×
4517
                                Version:                 r.Policy1Version.Int16,
×
4518
                                ChannelID:               r.GraphChannel.ID,
×
4519
                                NodeID:                  r.Policy1NodeID.Int64,
×
4520
                                Timelock:                r.Policy1Timelock.Int32,
×
4521
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
UNCOV
4522
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4523
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
UNCOV
4524
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4525
                                LastUpdate:              r.Policy1LastUpdate,
×
4526
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4527
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4528
                                Disabled:                r.Policy1Disabled,
×
4529
                                MessageFlags:            r.Policy1MessageFlags,
×
4530
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4531
                                Signature:               r.Policy1Signature,
×
4532
                        }
×
4533
                }
×
4534
                if r.Policy2ID.Valid {
×
4535
                        policy2 = &sqlc.GraphChannelPolicy{
×
4536
                                ID:                      r.Policy2ID.Int64,
×
4537
                                Version:                 r.Policy2Version.Int16,
×
4538
                                ChannelID:               r.GraphChannel.ID,
×
4539
                                NodeID:                  r.Policy2NodeID.Int64,
×
4540
                                Timelock:                r.Policy2Timelock.Int32,
×
4541
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4542
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4543
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4544
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4545
                                LastUpdate:              r.Policy2LastUpdate,
×
4546
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4547
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4548
                                Disabled:                r.Policy2Disabled,
×
4549
                                MessageFlags:            r.Policy2MessageFlags,
×
4550
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4551
                                Signature:               r.Policy2Signature,
×
4552
                        }
×
4553
                }
×
4554

4555
                return policy1, policy2, nil
×
4556
        default:
×
4557
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4558
                        "extractChannelPolicies: %T", r)
×
4559
        }
4560
}
4561

4562
// channelIDToBytes converts a channel ID (SCID) to a byte array
4563
// representation.
4564
func channelIDToBytes(channelID uint64) []byte {
×
4565
        var chanIDB [8]byte
×
UNCOV
4566
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4567

×
4568
        return chanIDB[:]
×
4569
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc