• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15880679691

25 Jun 2025 03:32PM UTC coverage: 57.316% (+0.6%) from 56.712%
15880679691

push

github

web-flow
Merge pull request #9971 from ellemouton/graphSQL16-closed-scids

[16] graph/db: SQL closed SCIDs table and last few methods

2 of 159 new or added lines in 2 files covered. (1.26%)

235 existing lines in 9 files now uncovered.

97544 of 170187 relevant lines covered (57.32%)

1.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
const (
	// pageSize caps how many records a single paginated query is allowed
	// to return. The value is a starting point and may be tuned once
	// benchmarks are available.
	pageSize = 2000
)
37

38
// ProtocolVersion enumerates the gossip protocol versions a message can
// belong to.
type ProtocolVersion uint8

// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
const ProtocolV1 = ProtocolVersion(1)
46

47
// String returns a string representation of the protocol version.
48
func (v ProtocolVersion) String() string {
×
49
        return fmt.Sprintf("V%d", v)
×
50
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
//
158
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
159
// implement the V1Store interface incrementally. For any method not
160
// implemented,  things will fall back to the KVStore. This is ONLY the case
161
// for the time being while this struct is purely used in unit tests only.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178

179
        // Temporary fall-back to the KVStore so that we can implement the
180
        // interface incrementally.
181
        *KVStore
182
}
183

184
// A compile-time assertion to ensure that SQLStore implements the V1Store
185
// interface.
186
var _ V1Store = (*SQLStore)(nil)
187

188
// SQLStoreConfig holds the configuration for the SQLStore.
189
type SQLStoreConfig struct {
190
        // ChainHash is the genesis hash for the chain that all the gossip
191
        // messages in this store are aimed at.
192
        ChainHash chainhash.Hash
193
}
194

195
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
196
// storage backend.
197
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
198
        options ...StoreOptionModifier) (*SQLStore, error) {
×
199

×
200
        opts := DefaultOptions()
×
201
        for _, o := range options {
×
202
                o(opts)
×
203
        }
×
204

205
        if opts.NoMigration {
×
206
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
207
                        "supported for SQL stores")
×
208
        }
×
209

210
        s := &SQLStore{
×
211
                cfg:         cfg,
×
212
                db:          db,
×
213
                KVStore:     kvStore,
×
214
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
215
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
216
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
217
        }
×
218

×
219
        s.chanScheduler = batch.NewTimeScheduler(
×
220
                db, &s.cacheMu, opts.BatchCommitInterval,
×
221
        )
×
222
        s.nodeScheduler = batch.NewTimeScheduler(
×
223
                db, nil, opts.BatchCommitInterval,
×
224
        )
×
225

×
226
        return s, nil
×
227
}
228

229
// AddLightningNode adds a vertex/node to the graph database. If the node is not
230
// in the database from before, this will add a new, unconnected one to the
231
// graph. If it is present from before, this will update that node's
232
// information.
233
//
234
// NOTE: part of the V1Store interface.
235
func (s *SQLStore) AddLightningNode(ctx context.Context,
236
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
237

×
238
        r := &batch.Request[SQLQueries]{
×
239
                Opts: batch.NewSchedulerOptions(opts...),
×
240
                Do: func(queries SQLQueries) error {
×
241
                        _, err := upsertNode(ctx, queries, node)
×
242
                        return err
×
243
                },
×
244
        }
245

246
        return s.nodeScheduler.Execute(ctx, r)
×
247
}
248

249
// FetchLightningNode attempts to look up a target node by its identity public
250
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
251
// returned.
252
//
253
// NOTE: part of the V1Store interface.
254
func (s *SQLStore) FetchLightningNode(ctx context.Context,
255
        pubKey route.Vertex) (*models.LightningNode, error) {
×
256

×
257
        var node *models.LightningNode
×
258
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
259
                var err error
×
260
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
261

×
262
                return err
×
263
        }, sqldb.NoOpReset)
×
264
        if err != nil {
×
265
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
266
        }
×
267

268
        return node, nil
×
269
}
270

271
// HasLightningNode determines if the graph has a vertex identified by the
272
// target node identity public key. If the node exists in the database, a
273
// timestamp of when the data for the node was lasted updated is returned along
274
// with a true boolean. Otherwise, an empty time.Time is returned with a false
275
// boolean.
276
//
277
// NOTE: part of the V1Store interface.
278
func (s *SQLStore) HasLightningNode(ctx context.Context,
279
        pubKey [33]byte) (time.Time, bool, error) {
×
280

×
281
        var (
×
282
                exists     bool
×
283
                lastUpdate time.Time
×
284
        )
×
285
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
286
                dbNode, err := db.GetNodeByPubKey(
×
287
                        ctx, sqlc.GetNodeByPubKeyParams{
×
288
                                Version: int16(ProtocolV1),
×
289
                                PubKey:  pubKey[:],
×
290
                        },
×
291
                )
×
292
                if errors.Is(err, sql.ErrNoRows) {
×
293
                        return nil
×
294
                } else if err != nil {
×
295
                        return fmt.Errorf("unable to fetch node: %w", err)
×
296
                }
×
297

298
                exists = true
×
299

×
300
                if dbNode.LastUpdate.Valid {
×
301
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
302
                }
×
303

304
                return nil
×
305
        }, sqldb.NoOpReset)
306
        if err != nil {
×
307
                return time.Time{}, false,
×
308
                        fmt.Errorf("unable to fetch node: %w", err)
×
309
        }
×
310

311
        return lastUpdate, exists, nil
×
312
}
313

314
// AddrsForNode returns all known addresses for the target node public key
315
// that the graph DB is aware of. The returned boolean indicates if the
316
// given node is unknown to the graph DB or not.
317
//
318
// NOTE: part of the V1Store interface.
319
func (s *SQLStore) AddrsForNode(ctx context.Context,
320
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
321

×
322
        var (
×
323
                addresses []net.Addr
×
324
                known     bool
×
325
        )
×
326
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
327
                var err error
×
328
                known, addresses, err = getNodeAddresses(
×
329
                        ctx, db, nodePub.SerializeCompressed(),
×
330
                )
×
331
                if err != nil {
×
332
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
333
                                err)
×
334
                }
×
335

336
                return nil
×
337
        }, sqldb.NoOpReset)
338
        if err != nil {
×
339
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
340
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
341
        }
×
342

343
        return known, addresses, nil
×
344
}
345

346
// DeleteLightningNode starts a new database transaction to remove a vertex/node
347
// from the database according to the node's public key.
348
//
349
// NOTE: part of the V1Store interface.
350
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
351
        pubKey route.Vertex) error {
×
352

×
353
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
354
                res, err := db.DeleteNodeByPubKey(
×
355
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
356
                                Version: int16(ProtocolV1),
×
357
                                PubKey:  pubKey[:],
×
358
                        },
×
359
                )
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                rows, err := res.RowsAffected()
×
365
                if err != nil {
×
366
                        return err
×
367
                }
×
368

369
                if rows == 0 {
×
370
                        return ErrGraphNodeNotFound
×
371
                } else if rows > 1 {
×
372
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
373
                }
×
374

375
                return err
×
376
        }, sqldb.NoOpReset)
377
        if err != nil {
×
378
                return fmt.Errorf("unable to delete node: %w", err)
×
379
        }
×
380

381
        return nil
×
382
}
383

384
// FetchNodeFeatures returns the features of the given node. If no features are
385
// known for the node, an empty feature vector is returned.
386
//
387
// NOTE: this is part of the graphdb.NodeTraverser interface.
388
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
389
        *lnwire.FeatureVector, error) {
×
390

×
391
        ctx := context.TODO()
×
392

×
393
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
394
}
×
395

396
// DisabledChannelIDs returns the channel ids of disabled channels.
397
// A channel is disabled when two of the associated ChanelEdgePolicies
398
// have their disabled bit on.
399
//
400
// NOTE: part of the V1Store interface.
401
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
402
        var (
×
403
                ctx     = context.TODO()
×
404
                chanIDs []uint64
×
405
        )
×
406
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
407
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
408
                if err != nil {
×
409
                        return fmt.Errorf("unable to fetch disabled "+
×
410
                                "channels: %w", err)
×
411
                }
×
412

413
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
414

×
415
                return nil
×
416
        }, sqldb.NoOpReset)
417
        if err != nil {
×
418
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
419
                        err)
×
420
        }
×
421

422
        return chanIDs, nil
×
423
}
424

425
// LookupAlias attempts to return the alias as advertised by the target node.
426
//
427
// NOTE: part of the V1Store interface.
428
func (s *SQLStore) LookupAlias(ctx context.Context,
429
        pub *btcec.PublicKey) (string, error) {
×
430

×
431
        var alias string
×
432
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
433
                dbNode, err := db.GetNodeByPubKey(
×
434
                        ctx, sqlc.GetNodeByPubKeyParams{
×
435
                                Version: int16(ProtocolV1),
×
436
                                PubKey:  pub.SerializeCompressed(),
×
437
                        },
×
438
                )
×
439
                if errors.Is(err, sql.ErrNoRows) {
×
440
                        return ErrNodeAliasNotFound
×
441
                } else if err != nil {
×
442
                        return fmt.Errorf("unable to fetch node: %w", err)
×
443
                }
×
444

445
                if !dbNode.Alias.Valid {
×
446
                        return ErrNodeAliasNotFound
×
447
                }
×
448

449
                alias = dbNode.Alias.String
×
450

×
451
                return nil
×
452
        }, sqldb.NoOpReset)
453
        if err != nil {
×
454
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
455
        }
×
456

457
        return alias, nil
×
458
}
459

460
// SourceNode returns the source node of the graph. The source node is treated
461
// as the center node within a star-graph. This method may be used to kick off
462
// a path finding algorithm in order to explore the reachability of another
463
// node based off the source node.
464
//
465
// NOTE: part of the V1Store interface.
466
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
467
        error) {
×
468

×
469
        var node *models.LightningNode
×
470
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
471
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
472
                if err != nil {
×
473
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
474
                                err)
×
475
                }
×
476

477
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
478

×
479
                return err
×
480
        }, sqldb.NoOpReset)
481
        if err != nil {
×
482
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
483
        }
×
484

485
        return node, nil
×
486
}
487

488
// SetSourceNode sets the source node within the graph database. The source
489
// node is to be used as the center of a star-graph within path finding
490
// algorithms.
491
//
492
// NOTE: part of the V1Store interface.
493
func (s *SQLStore) SetSourceNode(ctx context.Context,
494
        node *models.LightningNode) error {
×
495

×
496
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
497
                id, err := upsertNode(ctx, db, node)
×
498
                if err != nil {
×
499
                        return fmt.Errorf("unable to upsert source node: %w",
×
500
                                err)
×
501
                }
×
502

503
                // Make sure that if a source node for this version is already
504
                // set, then the ID is the same as the one we are about to set.
505
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
506
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
507
                        return fmt.Errorf("unable to fetch source node: %w",
×
508
                                err)
×
509
                } else if err == nil {
×
510
                        if dbSourceNodeID != id {
×
511
                                return fmt.Errorf("v1 source node already "+
×
512
                                        "set to a different node: %d vs %d",
×
513
                                        dbSourceNodeID, id)
×
514
                        }
×
515

516
                        return nil
×
517
                }
518

519
                return db.AddSourceNode(ctx, id)
×
520
        }, sqldb.NoOpReset)
521
}
522

523
// NodeUpdatesInHorizon returns all the known lightning node which have an
524
// update timestamp within the passed range. This method can be used by two
525
// nodes to quickly determine if they have the same set of up to date node
526
// announcements.
527
//
528
// NOTE: This is part of the V1Store interface.
529
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
530
        endTime time.Time) ([]models.LightningNode, error) {
×
531

×
532
        ctx := context.TODO()
×
533

×
534
        var nodes []models.LightningNode
×
535
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
536
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
537
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
538
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
539
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
540
                        },
×
541
                )
×
542
                if err != nil {
×
543
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
544
                }
×
545

546
                for _, dbNode := range dbNodes {
×
547
                        node, err := buildNode(ctx, db, &dbNode)
×
548
                        if err != nil {
×
549
                                return fmt.Errorf("unable to build node: %w",
×
550
                                        err)
×
551
                        }
×
552

553
                        nodes = append(nodes, *node)
×
554
                }
555

556
                return nil
×
557
        }, sqldb.NoOpReset)
558
        if err != nil {
×
559
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
560
        }
×
561

562
        return nodes, nil
×
563
}
564

565
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
566
// undirected edge from the two target nodes are created. The information stored
567
// denotes the static attributes of the channel, such as the channelID, the keys
568
// involved in creation of the channel, and the set of features that the channel
569
// supports. The chanPoint and chanID are used to uniquely identify the edge
570
// globally within the database.
571
//
572
// NOTE: part of the V1Store interface.
573
func (s *SQLStore) AddChannelEdge(ctx context.Context,
574
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
575

×
576
        var alreadyExists bool
×
577
        r := &batch.Request[SQLQueries]{
×
578
                Opts: batch.NewSchedulerOptions(opts...),
×
579
                Reset: func() {
×
580
                        alreadyExists = false
×
581
                },
×
582
                Do: func(tx SQLQueries) error {
×
583
                        err := insertChannel(ctx, tx, edge)
×
584

×
585
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
586
                        // succeed, but propagate the error via local state.
×
587
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
588
                                alreadyExists = true
×
589
                                return nil
×
590
                        }
×
591

592
                        return err
×
593
                },
594
                OnCommit: func(err error) error {
×
595
                        switch {
×
596
                        case err != nil:
×
597
                                return err
×
598
                        case alreadyExists:
×
599
                                return ErrEdgeAlreadyExist
×
600
                        default:
×
601
                                s.rejectCache.remove(edge.ChannelID)
×
602
                                s.chanCache.remove(edge.ChannelID)
×
603
                                return nil
×
604
                        }
605
                },
606
        }
607

608
        return s.chanScheduler.Execute(ctx, r)
×
609
}
610

611
// HighestChanID returns the "highest" known channel ID in the channel graph.
612
// This represents the "newest" channel from the PoV of the chain. This method
613
// can be used by peers to quickly determine if their graphs are in sync.
614
//
615
// NOTE: This is part of the V1Store interface.
616
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
617
        var highestChanID uint64
×
618
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
619
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
620
                if errors.Is(err, sql.ErrNoRows) {
×
621
                        return nil
×
622
                } else if err != nil {
×
623
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
624
                                err)
×
625
                }
×
626

627
                highestChanID = byteOrder.Uint64(chanID)
×
628

×
629
                return nil
×
630
        }, sqldb.NoOpReset)
631
        if err != nil {
×
632
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
633
        }
×
634

635
        return highestChanID, nil
×
636
}
637

638
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
639
// within the database for the referenced channel. The `flags` attribute within
640
// the ChannelEdgePolicy determines which of the directed edges are being
641
// updated. If the flag is 1, then the first node's information is being
642
// updated, otherwise it's the second node's information. The node ordering is
643
// determined by the lexicographical ordering of the identity public keys of the
644
// nodes on either side of the channel.
645
//
646
// NOTE: part of the V1Store interface.
647
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
648
        edge *models.ChannelEdgePolicy,
649
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
650

×
651
        var (
×
652
                isUpdate1    bool
×
653
                edgeNotFound bool
×
654
                from, to     route.Vertex
×
655
        )
×
656

×
657
        r := &batch.Request[SQLQueries]{
×
658
                Opts: batch.NewSchedulerOptions(opts...),
×
659
                Reset: func() {
×
660
                        isUpdate1 = false
×
661
                        edgeNotFound = false
×
662
                },
×
663
                Do: func(tx SQLQueries) error {
×
664
                        var err error
×
665
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
666
                                ctx, tx, edge,
×
667
                        )
×
668
                        if err != nil {
×
669
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
670
                        }
×
671

672
                        // Silence ErrEdgeNotFound so that the batch can
673
                        // succeed, but propagate the error via local state.
674
                        if errors.Is(err, ErrEdgeNotFound) {
×
675
                                edgeNotFound = true
×
676
                                return nil
×
677
                        }
×
678

679
                        return err
×
680
                },
681
                OnCommit: func(err error) error {
×
682
                        switch {
×
683
                        case err != nil:
×
684
                                return err
×
685
                        case edgeNotFound:
×
686
                                return ErrEdgeNotFound
×
687
                        default:
×
688
                                s.updateEdgeCache(edge, isUpdate1)
×
689
                                return nil
×
690
                        }
691
                },
692
        }
693

694
        err := s.chanScheduler.Execute(ctx, r)
×
695

×
696
        return from, to, err
×
697
}
698

699
// updateEdgeCache updates our reject and channel caches with the new
700
// edge policy information.
701
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
702
        isUpdate1 bool) {
×
703

×
704
        // If an entry for this channel is found in reject cache, we'll modify
×
705
        // the entry with the updated timestamp for the direction that was just
×
706
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
707
        // during the next query for this edge.
×
708
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
709
                if isUpdate1 {
×
710
                        entry.upd1Time = e.LastUpdate.Unix()
×
711
                } else {
×
712
                        entry.upd2Time = e.LastUpdate.Unix()
×
713
                }
×
714
                s.rejectCache.insert(e.ChannelID, entry)
×
715
        }
716

717
        // If an entry for this channel is found in channel cache, we'll modify
718
        // the entry with the updated policy for the direction that was just
719
        // written. If the edge doesn't exist, we'll defer loading the info and
720
        // policies and lazily read from disk during the next query.
721
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
722
                if isUpdate1 {
×
723
                        channel.Policy1 = e
×
724
                } else {
×
725
                        channel.Policy2 = e
×
726
                }
×
727
                s.chanCache.insert(e.ChannelID, channel)
×
728
        }
729
}
730

731
// ForEachSourceNodeChannel iterates through all channels of the source node,
732
// executing the passed callback on each. The call-back is provided with the
733
// channel's outpoint, whether we have a policy for the channel and the channel
734
// peer's node information.
735
//
736
// NOTE: part of the V1Store interface.
737
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
738
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
739

×
740
        var ctx = context.TODO()
×
741

×
742
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
743
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
744
                if err != nil {
×
745
                        return fmt.Errorf("unable to fetch source node: %w",
×
746
                                err)
×
747
                }
×
748

749
                return forEachNodeChannel(
×
750
                        ctx, db, s.cfg.ChainHash, nodeID,
×
751
                        func(info *models.ChannelEdgeInfo,
×
752
                                outPolicy *models.ChannelEdgePolicy,
×
753
                                _ *models.ChannelEdgePolicy) error {
×
754

×
755
                                // Fetch the other node.
×
756
                                var (
×
757
                                        otherNodePub [33]byte
×
758
                                        node1        = info.NodeKey1Bytes
×
759
                                        node2        = info.NodeKey2Bytes
×
760
                                )
×
761
                                switch {
×
762
                                case bytes.Equal(node1[:], nodePub[:]):
×
763
                                        otherNodePub = node2
×
764
                                case bytes.Equal(node2[:], nodePub[:]):
×
765
                                        otherNodePub = node1
×
766
                                default:
×
767
                                        return fmt.Errorf("node not " +
×
768
                                                "participating in this channel")
×
769
                                }
770

771
                                _, otherNode, err := getNodeByPubKey(
×
772
                                        ctx, db, otherNodePub,
×
773
                                )
×
774
                                if err != nil {
×
775
                                        return fmt.Errorf("unable to fetch "+
×
776
                                                "other node(%x): %w",
×
777
                                                otherNodePub, err)
×
778
                                }
×
779

780
                                return cb(
×
781
                                        info.ChannelPoint, outPolicy != nil,
×
782
                                        otherNode,
×
783
                                )
×
784
                        },
785
                )
786
        }, sqldb.NoOpReset)
787
}
788

789
// ForEachNode iterates through all the stored vertices/nodes in the graph,
790
// executing the passed callback with each node encountered. If the callback
791
// returns an error, then the transaction is aborted and the iteration stops
792
// early. Any operations performed on the NodeTx passed to the call-back are
793
// executed under the same read transaction and so, methods on the NodeTx object
794
// _MUST_ only be called from within the call-back.
795
//
796
// NOTE: part of the V1Store interface.
797
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
798
        var (
×
799
                ctx          = context.TODO()
×
800
                lastID int64 = 0
×
801
        )
×
802

×
803
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
804
                node, err := buildNode(ctx, db, &dbNode)
×
805
                if err != nil {
×
806
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
807
                                dbNode.ID, err)
×
808
                }
×
809

810
                err = cb(
×
811
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
812
                )
×
813
                if err != nil {
×
814
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
815
                                dbNode.ID, err)
×
816
                }
×
817

818
                return nil
×
819
        }
820

821
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
822
                for {
×
823
                        nodes, err := db.ListNodesPaginated(
×
824
                                ctx, sqlc.ListNodesPaginatedParams{
×
825
                                        Version: int16(ProtocolV1),
×
826
                                        ID:      lastID,
×
827
                                        Limit:   pageSize,
×
828
                                },
×
829
                        )
×
830
                        if err != nil {
×
831
                                return fmt.Errorf("unable to fetch nodes: %w",
×
832
                                        err)
×
833
                        }
×
834

835
                        if len(nodes) == 0 {
×
836
                                break
×
837
                        }
838

839
                        for _, dbNode := range nodes {
×
840
                                err = handleNode(db, dbNode)
×
841
                                if err != nil {
×
842
                                        return err
×
843
                                }
×
844

845
                                lastID = dbNode.ID
×
846
                        }
847
                }
848

849
                return nil
×
850
        }, sqldb.NoOpReset)
851
}
852

853
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
854
// SQLStore and a SQL transaction.
855
type sqlGraphNodeTx struct {
856
        db    SQLQueries
857
        id    int64
858
        node  *models.LightningNode
859
        chain chainhash.Hash
860
}
861

862
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
863
// interface.
864
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
865

866
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
867
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
868

×
869
        return &sqlGraphNodeTx{
×
870
                db:    db,
×
871
                chain: chain,
×
872
                id:    id,
×
873
                node:  node,
×
874
        }
×
875
}
×
876

877
// Node returns the raw information of the node.
878
//
879
// NOTE: This is a part of the NodeRTx interface.
880
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
881
        return s.node
×
882
}
×
883

884
// ForEachChannel can be used to iterate over the node's channels under the same
885
// transaction used to fetch the node.
886
//
887
// NOTE: This is a part of the NodeRTx interface.
888
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
889
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
890

×
891
        ctx := context.TODO()
×
892

×
893
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
894
}
×
895

896
// FetchNode fetches the node with the given pub key under the same transaction
897
// used to fetch the current node. The returned node is also a NodeRTx and any
898
// operations on that NodeRTx will also be done under the same transaction.
899
//
900
// NOTE: This is a part of the NodeRTx interface.
901
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
902
        ctx := context.TODO()
×
903

×
904
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
905
        if err != nil {
×
906
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
907
                        nodePub, err)
×
908
        }
×
909

910
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
911
}
912

913
// ForEachNodeDirectedChannel iterates through all channels of a given node,
914
// executing the passed callback on the directed edge representing the channel
915
// and its incoming policy. If the callback returns an error, then the iteration
916
// is halted with the error propagated back up to the caller.
917
//
918
// Unknown policies are passed into the callback as nil values.
919
//
920
// NOTE: this is part of the graphdb.NodeTraverser interface.
921
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
922
        cb func(channel *DirectedChannel) error) error {
×
923

×
924
        var ctx = context.TODO()
×
925

×
926
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
927
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
928
        }, sqldb.NoOpReset)
×
929
}
930

931
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
932
// graph, executing the passed callback with each node encountered. If the
933
// callback returns an error, then the transaction is aborted and the iteration
934
// stops early.
935
//
936
// NOTE: This is a part of the V1Store interface.
937
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
938
        *lnwire.FeatureVector) error) error {
×
939

×
940
        ctx := context.TODO()
×
941

×
942
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
943
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
944
                        nodePub route.Vertex) error {
×
945

×
946
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
947
                        if err != nil {
×
948
                                return fmt.Errorf("unable to fetch node "+
×
949
                                        "features: %w", err)
×
950
                        }
×
951

952
                        return cb(nodePub, features)
×
953
                })
954
        }, sqldb.NoOpReset)
955
        if err != nil {
×
956
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
957
        }
×
958

959
        return nil
×
960
}
961

962
// ForEachNodeChannel iterates through all channels of the given node,
963
// executing the passed callback with an edge info structure and the policies
964
// of each end of the channel. The first edge policy is the outgoing edge *to*
965
// the connecting node, while the second is the incoming edge *from* the
966
// connecting node. If the callback returns an error, then the iteration is
967
// halted with the error propagated back up to the caller.
968
//
969
// Unknown policies are passed into the callback as nil values.
970
//
971
// NOTE: part of the V1Store interface.
972
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
973
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
974
                *models.ChannelEdgePolicy) error) error {
×
975

×
976
        var ctx = context.TODO()
×
977

×
978
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
979
                dbNode, err := db.GetNodeByPubKey(
×
980
                        ctx, sqlc.GetNodeByPubKeyParams{
×
981
                                Version: int16(ProtocolV1),
×
982
                                PubKey:  nodePub[:],
×
983
                        },
×
984
                )
×
985
                if errors.Is(err, sql.ErrNoRows) {
×
986
                        return nil
×
987
                } else if err != nil {
×
988
                        return fmt.Errorf("unable to fetch node: %w", err)
×
989
                }
×
990

991
                return forEachNodeChannel(
×
992
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
993
                )
×
994
        }, sqldb.NoOpReset)
995
}
996

997
// ChanUpdatesInHorizon returns all the known channel edges which have at least
998
// one edge that has an update timestamp within the specified horizon.
999
//
1000
// NOTE: This is part of the V1Store interface.
1001
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
1002
        endTime time.Time) ([]ChannelEdge, error) {
×
1003

×
1004
        s.cacheMu.Lock()
×
1005
        defer s.cacheMu.Unlock()
×
1006

×
1007
        var (
×
1008
                ctx = context.TODO()
×
1009
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1010
                // an additional map to keep track of the edges already seen to
×
1011
                // prevent re-adding it.
×
1012
                edgesSeen    = make(map[uint64]struct{})
×
1013
                edgesToCache = make(map[uint64]ChannelEdge)
×
1014
                edges        []ChannelEdge
×
1015
                hits         int
×
1016
        )
×
1017
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1018
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1019
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1020
                                Version:   int16(ProtocolV1),
×
1021
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1022
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1023
                        },
×
1024
                )
×
1025
                if err != nil {
×
1026
                        return err
×
1027
                }
×
1028

1029
                for _, row := range rows {
×
1030
                        // If we've already retrieved the info and policies for
×
1031
                        // this edge, then we can skip it as we don't need to do
×
1032
                        // so again.
×
1033
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1034
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1035
                                continue
×
1036
                        }
1037

1038
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1039
                                hits++
×
1040
                                edgesSeen[chanIDInt] = struct{}{}
×
1041
                                edges = append(edges, channel)
×
1042

×
1043
                                continue
×
1044
                        }
1045

1046
                        node1, node2, err := buildNodes(
×
1047
                                ctx, db, row.Node, row.Node_2,
×
1048
                        )
×
1049
                        if err != nil {
×
1050
                                return err
×
1051
                        }
×
1052

1053
                        channel, err := getAndBuildEdgeInfo(
×
1054
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1055
                                row.Channel, node1.PubKeyBytes,
×
1056
                                node2.PubKeyBytes,
×
1057
                        )
×
1058
                        if err != nil {
×
1059
                                return fmt.Errorf("unable to build channel "+
×
1060
                                        "info: %w", err)
×
1061
                        }
×
1062

1063
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1064
                        if err != nil {
×
1065
                                return fmt.Errorf("unable to extract channel "+
×
1066
                                        "policies: %w", err)
×
1067
                        }
×
1068

1069
                        p1, p2, err := getAndBuildChanPolicies(
×
1070
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1071
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1072
                        )
×
1073
                        if err != nil {
×
1074
                                return fmt.Errorf("unable to build channel "+
×
1075
                                        "policies: %w", err)
×
1076
                        }
×
1077

1078
                        edgesSeen[chanIDInt] = struct{}{}
×
1079
                        chanEdge := ChannelEdge{
×
1080
                                Info:    channel,
×
1081
                                Policy1: p1,
×
1082
                                Policy2: p2,
×
1083
                                Node1:   node1,
×
1084
                                Node2:   node2,
×
1085
                        }
×
1086
                        edges = append(edges, chanEdge)
×
1087
                        edgesToCache[chanIDInt] = chanEdge
×
1088
                }
1089

1090
                return nil
×
1091
        }, func() {
×
1092
                edgesSeen = make(map[uint64]struct{})
×
1093
                edgesToCache = make(map[uint64]ChannelEdge)
×
1094
                edges = nil
×
1095
        })
×
1096
        if err != nil {
×
1097
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1098
        }
×
1099

1100
        // Insert any edges loaded from disk into the cache.
1101
        for chanid, channel := range edgesToCache {
×
1102
                s.chanCache.insert(chanid, channel)
×
1103
        }
×
1104

1105
        if len(edges) > 0 {
×
NEW
1106
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
NEW
1107
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1108
        } else {
×
1109
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1110
                        "horizon (%s, %s)", startTime, endTime)
×
1111
        }
×
1112

1113
        return edges, nil
×
1114
}
1115

1116
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1117
// data to the call-back.
1118
//
1119
// NOTE: The callback contents MUST not be modified.
1120
//
1121
// NOTE: part of the V1Store interface.
1122
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1123
        chans map[uint64]*DirectedChannel) error) error {
×
1124

×
1125
        var ctx = context.TODO()
×
1126

×
1127
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1128
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1129
                        nodePub route.Vertex) error {
×
1130

×
1131
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1132
                        if err != nil {
×
1133
                                return fmt.Errorf("unable to fetch "+
×
1134
                                        "node(id=%d) features: %w", nodeID, err)
×
1135
                        }
×
1136

1137
                        toNodeCallback := func() route.Vertex {
×
1138
                                return nodePub
×
1139
                        }
×
1140

1141
                        rows, err := db.ListChannelsByNodeID(
×
1142
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1143
                                        Version: int16(ProtocolV1),
×
1144
                                        NodeID1: nodeID,
×
1145
                                },
×
1146
                        )
×
1147
                        if err != nil {
×
1148
                                return fmt.Errorf("unable to fetch channels "+
×
1149
                                        "of node(id=%d): %w", nodeID, err)
×
1150
                        }
×
1151

1152
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1153
                        for _, row := range rows {
×
1154
                                node1, node2, err := buildNodeVertices(
×
1155
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1156
                                )
×
1157
                                if err != nil {
×
1158
                                        return err
×
1159
                                }
×
1160

1161
                                e, err := getAndBuildEdgeInfo(
×
1162
                                        ctx, db, s.cfg.ChainHash,
×
1163
                                        row.Channel.ID, row.Channel, node1,
×
1164
                                        node2,
×
1165
                                )
×
1166
                                if err != nil {
×
1167
                                        return fmt.Errorf("unable to build "+
×
1168
                                                "channel info: %w", err)
×
1169
                                }
×
1170

1171
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1172
                                        row,
×
1173
                                )
×
1174
                                if err != nil {
×
1175
                                        return fmt.Errorf("unable to "+
×
1176
                                                "extract channel "+
×
1177
                                                "policies: %w", err)
×
1178
                                }
×
1179

1180
                                p1, p2, err := getAndBuildChanPolicies(
×
1181
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1182
                                        node1, node2,
×
1183
                                )
×
1184
                                if err != nil {
×
1185
                                        return fmt.Errorf("unable to "+
×
1186
                                                "build channel policies: %w",
×
1187
                                                err)
×
1188
                                }
×
1189

1190
                                // Determine the outgoing and incoming policy
1191
                                // for this channel and node combo.
1192
                                outPolicy, inPolicy := p1, p2
×
1193
                                if p1 != nil && p1.ToNode == nodePub {
×
1194
                                        outPolicy, inPolicy = p2, p1
×
1195
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1196
                                        outPolicy, inPolicy = p2, p1
×
1197
                                }
×
1198

1199
                                var cachedInPolicy *models.CachedEdgePolicy
×
1200
                                if inPolicy != nil {
×
1201
                                        cachedInPolicy = models.NewCachedPolicy(
×
1202
                                                p2,
×
1203
                                        )
×
1204
                                        cachedInPolicy.ToNodePubKey =
×
1205
                                                toNodeCallback
×
1206
                                        cachedInPolicy.ToNodeFeatures =
×
1207
                                                features
×
1208
                                }
×
1209

1210
                                var inboundFee lnwire.Fee
×
1211
                                outPolicy.InboundFee.WhenSome(
×
1212
                                        func(fee lnwire.Fee) {
×
1213
                                                inboundFee = fee
×
1214
                                        },
×
1215
                                )
1216

1217
                                directedChannel := &DirectedChannel{
×
1218
                                        ChannelID: e.ChannelID,
×
1219
                                        IsNode1: nodePub ==
×
1220
                                                e.NodeKey1Bytes,
×
1221
                                        OtherNode:    e.NodeKey2Bytes,
×
1222
                                        Capacity:     e.Capacity,
×
1223
                                        OutPolicySet: p1 != nil,
×
1224
                                        InPolicy:     cachedInPolicy,
×
1225
                                        InboundFee:   inboundFee,
×
1226
                                }
×
1227

×
1228
                                if nodePub == e.NodeKey2Bytes {
×
1229
                                        directedChannel.OtherNode =
×
1230
                                                e.NodeKey1Bytes
×
1231
                                }
×
1232

1233
                                channels[e.ChannelID] = directedChannel
×
1234
                        }
1235

1236
                        return cb(nodePub, channels)
×
1237
                })
1238
        }, sqldb.NoOpReset)
1239
}
1240

1241
// ForEachChannelCacheable iterates through all the channel edges stored
1242
// within the graph and invokes the passed callback for each edge. The
1243
// callback takes two edges as since this is a directed graph, both the
1244
// in/out edges are visited. If the callback returns an error, then the
1245
// transaction is aborted and the iteration stops early.
1246
//
1247
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1248
// pointer for that particular channel edge routing policy will be
1249
// passed into the callback.
1250
//
1251
// NOTE: this method is like ForEachChannel but fetches only the data
1252
// required for the graph cache.
1253
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1254
        *models.CachedEdgePolicy,
NEW
1255
        *models.CachedEdgePolicy) error) error {
×
NEW
1256

×
NEW
1257
        ctx := context.TODO()
×
NEW
1258

×
NEW
1259
        handleChannel := func(db SQLQueries,
×
NEW
1260
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
NEW
1261

×
NEW
1262
                node1, node2, err := buildNodeVertices(
×
NEW
1263
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1264
                )
×
NEW
1265
                if err != nil {
×
NEW
1266
                        return err
×
NEW
1267
                }
×
1268

NEW
1269
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
NEW
1270

×
NEW
1271
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1272
                if err != nil {
×
NEW
1273
                        return err
×
NEW
1274
                }
×
1275

NEW
1276
                var pol1, pol2 *models.CachedEdgePolicy
×
NEW
1277
                if dbPol1 != nil {
×
NEW
1278
                        policy1, err := buildChanPolicy(
×
NEW
1279
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
NEW
1280
                        )
×
NEW
1281
                        if err != nil {
×
NEW
1282
                                return err
×
NEW
1283
                        }
×
1284

NEW
1285
                        pol1 = models.NewCachedPolicy(policy1)
×
1286
                }
NEW
1287
                if dbPol2 != nil {
×
NEW
1288
                        policy2, err := buildChanPolicy(
×
NEW
1289
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
NEW
1290
                        )
×
NEW
1291
                        if err != nil {
×
NEW
1292
                                return err
×
NEW
1293
                        }
×
1294

NEW
1295
                        pol2 = models.NewCachedPolicy(policy2)
×
1296
                }
1297

NEW
1298
                if err := cb(edge, pol1, pol2); err != nil {
×
NEW
1299
                        return err
×
NEW
1300
                }
×
1301

NEW
1302
                return nil
×
1303
        }
1304

NEW
1305
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1306
                lastID := int64(-1)
×
NEW
1307
                for {
×
NEW
1308
                        //nolint:ll
×
NEW
1309
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
NEW
1310
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
1311
                                        Version: int16(ProtocolV1),
×
NEW
1312
                                        ID:      lastID,
×
NEW
1313
                                        Limit:   pageSize,
×
NEW
1314
                                },
×
NEW
1315
                        )
×
NEW
1316
                        if err != nil {
×
NEW
1317
                                return err
×
NEW
1318
                        }
×
1319

NEW
1320
                        if len(rows) == 0 {
×
NEW
1321
                                break
×
1322
                        }
1323

NEW
1324
                        for _, row := range rows {
×
NEW
1325
                                err := handleChannel(db, row)
×
NEW
1326
                                if err != nil {
×
NEW
1327
                                        return err
×
NEW
1328
                                }
×
1329

NEW
1330
                                lastID = row.Channel.ID
×
1331
                        }
1332
                }
1333

NEW
1334
                return nil
×
1335
        }, sqldb.NoOpReset)
1336
}
1337

1338
// ForEachChannel iterates through all the channel edges stored within the
1339
// graph and invokes the passed callback for each edge. The callback takes two
1340
// edges as since this is a directed graph, both the in/out edges are visited.
1341
// If the callback returns an error, then the transaction is aborted and the
1342
// iteration stops early.
1343
//
1344
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1345
// for that particular channel edge routing policy will be passed into the
1346
// callback.
1347
//
1348
// NOTE: part of the V1Store interface.
1349
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1350
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1351

×
1352
        ctx := context.TODO()
×
1353

×
1354
        handleChannel := func(db SQLQueries,
×
1355
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1356

×
1357
                node1, node2, err := buildNodeVertices(
×
1358
                        row.Node1Pubkey, row.Node2Pubkey,
×
1359
                )
×
1360
                if err != nil {
×
1361
                        return fmt.Errorf("unable to build node vertices: %w",
×
1362
                                err)
×
1363
                }
×
1364

1365
                edge, err := getAndBuildEdgeInfo(
×
1366
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1367
                        node1, node2,
×
1368
                )
×
1369
                if err != nil {
×
1370
                        return fmt.Errorf("unable to build channel info: %w",
×
1371
                                err)
×
1372
                }
×
1373

1374
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1375
                if err != nil {
×
1376
                        return fmt.Errorf("unable to extract channel "+
×
1377
                                "policies: %w", err)
×
1378
                }
×
1379

1380
                p1, p2, err := getAndBuildChanPolicies(
×
1381
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1382
                )
×
1383
                if err != nil {
×
1384
                        return fmt.Errorf("unable to build channel "+
×
1385
                                "policies: %w", err)
×
1386
                }
×
1387

1388
                err = cb(edge, p1, p2)
×
1389
                if err != nil {
×
1390
                        return fmt.Errorf("callback failed for channel "+
×
1391
                                "id=%d: %w", edge.ChannelID, err)
×
1392
                }
×
1393

1394
                return nil
×
1395
        }
1396

1397
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1398
                lastID := int64(-1)
×
1399
                for {
×
1400
                        //nolint:ll
×
1401
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1402
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1403
                                        Version: int16(ProtocolV1),
×
1404
                                        ID:      lastID,
×
1405
                                        Limit:   pageSize,
×
1406
                                },
×
1407
                        )
×
1408
                        if err != nil {
×
1409
                                return err
×
1410
                        }
×
1411

1412
                        if len(rows) == 0 {
×
1413
                                break
×
1414
                        }
1415

1416
                        for _, row := range rows {
×
1417
                                err := handleChannel(db, row)
×
1418
                                if err != nil {
×
1419
                                        return err
×
1420
                                }
×
1421

1422
                                lastID = row.Channel.ID
×
1423
                        }
1424
                }
1425

1426
                return nil
×
1427
        }, sqldb.NoOpReset)
1428
}
1429

1430
// FilterChannelRange returns the channel ID's of all known channels which were
1431
// mined in a block height within the passed range. The channel IDs are grouped
1432
// by their common block height. This method can be used to quickly share with a
1433
// peer the set of channels we know of within a particular range to catch them
1434
// up after a period of time offline. If withTimestamps is true then the
1435
// timestamp info of the latest received channel update messages of the channel
1436
// will be included in the response.
1437
//
1438
// NOTE: This is part of the V1Store interface.
1439
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1440
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1441

×
1442
        var (
×
1443
                ctx       = context.TODO()
×
1444
                startSCID = &lnwire.ShortChannelID{
×
1445
                        BlockHeight: startHeight,
×
1446
                }
×
1447
                endSCID = lnwire.ShortChannelID{
×
1448
                        BlockHeight: endHeight,
×
1449
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1450
                        TxPosition:  math.MaxUint16,
×
1451
                }
×
1452
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1453
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1454
        )
×
1455

×
1456
        // 1) get all channels where channelID is between start and end chan ID.
×
1457
        // 2) skip if not public (ie, no channel_proof)
×
1458
        // 3) collect that channel.
×
1459
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1460
        //    and add those timestamps to the collected channel.
×
1461
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1462
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1463
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1464
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1465
                                StartScid: chanIDStart[:],
×
1466
                                EndScid:   chanIDEnd[:],
×
1467
                        },
×
1468
                )
×
1469
                if err != nil {
×
1470
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1471
                                err)
×
1472
                }
×
1473

1474
                for _, dbChan := range dbChans {
×
1475
                        cid := lnwire.NewShortChanIDFromInt(
×
1476
                                byteOrder.Uint64(dbChan.Scid),
×
1477
                        )
×
1478
                        chanInfo := NewChannelUpdateInfo(
×
1479
                                cid, time.Time{}, time.Time{},
×
1480
                        )
×
1481

×
1482
                        if !withTimestamps {
×
1483
                                channelsPerBlock[cid.BlockHeight] = append(
×
1484
                                        channelsPerBlock[cid.BlockHeight],
×
1485
                                        chanInfo,
×
1486
                                )
×
1487

×
1488
                                continue
×
1489
                        }
1490

1491
                        //nolint:ll
1492
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1493
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1494
                                        Version:   int16(ProtocolV1),
×
1495
                                        ChannelID: dbChan.ID,
×
1496
                                        NodeID:    dbChan.NodeID1,
×
1497
                                },
×
1498
                        )
×
1499
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1500
                                return fmt.Errorf("unable to fetch node1 "+
×
1501
                                        "policy: %w", err)
×
1502
                        } else if err == nil {
×
1503
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1504
                                        node1Policy.LastUpdate.Int64, 0,
×
1505
                                )
×
1506
                        }
×
1507

1508
                        //nolint:ll
1509
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1510
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1511
                                        Version:   int16(ProtocolV1),
×
1512
                                        ChannelID: dbChan.ID,
×
1513
                                        NodeID:    dbChan.NodeID2,
×
1514
                                },
×
1515
                        )
×
1516
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1517
                                return fmt.Errorf("unable to fetch node2 "+
×
1518
                                        "policy: %w", err)
×
1519
                        } else if err == nil {
×
1520
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1521
                                        node2Policy.LastUpdate.Int64, 0,
×
1522
                                )
×
1523
                        }
×
1524

1525
                        channelsPerBlock[cid.BlockHeight] = append(
×
1526
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1527
                        )
×
1528
                }
1529

1530
                return nil
×
1531
        }, func() {
×
1532
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1533
        })
×
1534
        if err != nil {
×
1535
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1536
        }
×
1537

1538
        if len(channelsPerBlock) == 0 {
×
1539
                return nil, nil
×
1540
        }
×
1541

1542
        // Return the channel ranges in ascending block height order.
1543
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1544
        slices.Sort(blocks)
×
1545

×
1546
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1547
                return BlockChannelRange{
×
1548
                        Height:   block,
×
1549
                        Channels: channelsPerBlock[block],
×
1550
                }
×
1551
        }), nil
×
1552
}
1553

1554
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1555
// zombie. This method is used on an ad-hoc basis, when channels need to be
1556
// marked as zombies outside the normal pruning cycle.
1557
//
1558
// NOTE: part of the V1Store interface.
1559
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1560
        pubKey1, pubKey2 [33]byte) error {
×
1561

×
1562
        ctx := context.TODO()
×
1563

×
1564
        s.cacheMu.Lock()
×
1565
        defer s.cacheMu.Unlock()
×
1566

×
1567
        chanIDB := channelIDToBytes(chanID)
×
1568

×
1569
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1570
                return db.UpsertZombieChannel(
×
1571
                        ctx, sqlc.UpsertZombieChannelParams{
×
1572
                                Version:  int16(ProtocolV1),
×
1573
                                Scid:     chanIDB[:],
×
1574
                                NodeKey1: pubKey1[:],
×
1575
                                NodeKey2: pubKey2[:],
×
1576
                        },
×
1577
                )
×
1578
        }, sqldb.NoOpReset)
×
1579
        if err != nil {
×
1580
                return fmt.Errorf("unable to upsert zombie channel "+
×
1581
                        "(channel_id=%d): %w", chanID, err)
×
1582
        }
×
1583

1584
        s.rejectCache.remove(chanID)
×
1585
        s.chanCache.remove(chanID)
×
1586

×
1587
        return nil
×
1588
}
1589

1590
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1591
//
1592
// NOTE: part of the V1Store interface.
1593
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1594
        s.cacheMu.Lock()
×
1595
        defer s.cacheMu.Unlock()
×
1596

×
1597
        var (
×
1598
                ctx     = context.TODO()
×
1599
                chanIDB = channelIDToBytes(chanID)
×
1600
        )
×
1601

×
1602
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1603
                res, err := db.DeleteZombieChannel(
×
1604
                        ctx, sqlc.DeleteZombieChannelParams{
×
1605
                                Scid:    chanIDB[:],
×
1606
                                Version: int16(ProtocolV1),
×
1607
                        },
×
1608
                )
×
1609
                if err != nil {
×
1610
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1611
                                err)
×
1612
                }
×
1613

1614
                rows, err := res.RowsAffected()
×
1615
                if err != nil {
×
1616
                        return err
×
1617
                }
×
1618

1619
                if rows == 0 {
×
1620
                        return ErrZombieEdgeNotFound
×
1621
                } else if rows > 1 {
×
1622
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1623
                                "expected 1", rows)
×
1624
                }
×
1625

1626
                return nil
×
1627
        }, sqldb.NoOpReset)
1628
        if err != nil {
×
1629
                return fmt.Errorf("unable to mark edge live "+
×
1630
                        "(channel_id=%d): %w", chanID, err)
×
1631
        }
×
1632

1633
        s.rejectCache.remove(chanID)
×
1634
        s.chanCache.remove(chanID)
×
1635

×
1636
        return err
×
1637
}
1638

1639
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1640
// zombie, then the two node public keys corresponding to this edge are also
1641
// returned.
1642
//
1643
// NOTE: part of the V1Store interface.
1644
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
×
1645
        var (
×
1646
                ctx              = context.TODO()
×
1647
                isZombie         bool
×
1648
                pubKey1, pubKey2 route.Vertex
×
1649
                chanIDB          = channelIDToBytes(chanID)
×
1650
        )
×
1651

×
1652
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1653
                zombie, err := db.GetZombieChannel(
×
1654
                        ctx, sqlc.GetZombieChannelParams{
×
1655
                                Scid:    chanIDB[:],
×
1656
                                Version: int16(ProtocolV1),
×
1657
                        },
×
1658
                )
×
1659
                if errors.Is(err, sql.ErrNoRows) {
×
1660
                        return nil
×
1661
                }
×
1662
                if err != nil {
×
1663
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1664
                                err)
×
1665
                }
×
1666

1667
                copy(pubKey1[:], zombie.NodeKey1)
×
1668
                copy(pubKey2[:], zombie.NodeKey2)
×
1669
                isZombie = true
×
1670

×
1671
                return nil
×
1672
        }, sqldb.NoOpReset)
1673
        if err != nil {
×
1674
                // TODO(elle): update the IsZombieEdge method to return an
×
1675
                // error.
×
1676
                return false, route.Vertex{}, route.Vertex{}
×
1677
        }
×
1678

1679
        return isZombie, pubKey1, pubKey2
×
1680
}
1681

1682
// NumZombies returns the current number of zombie channels in the graph.
1683
//
1684
// NOTE: part of the V1Store interface.
1685
func (s *SQLStore) NumZombies() (uint64, error) {
×
1686
        var (
×
1687
                ctx        = context.TODO()
×
1688
                numZombies uint64
×
1689
        )
×
1690
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1691
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1692
                if err != nil {
×
1693
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1694
                                err)
×
1695
                }
×
1696

1697
                numZombies = uint64(count)
×
1698

×
1699
                return nil
×
1700
        }, sqldb.NoOpReset)
1701
        if err != nil {
×
1702
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1703
        }
×
1704

1705
        return numZombies, nil
×
1706
}
1707

1708
// DeleteChannelEdges removes edges with the given channel IDs from the
1709
// database and marks them as zombies. This ensures that we're unable to re-add
1710
// it to our database once again. If an edge does not exist within the
1711
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1712
// true, then when we mark these edges as zombies, we'll set up the keys such
1713
// that we require the node that failed to send the fresh update to be the one
1714
// that resurrects the channel from its zombie state. The markZombie bool
1715
// denotes whether to mark the channel as a zombie.
1716
//
1717
// NOTE: part of the V1Store interface.
1718
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1719
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1720

×
1721
        s.cacheMu.Lock()
×
1722
        defer s.cacheMu.Unlock()
×
1723

×
1724
        var (
×
1725
                ctx     = context.TODO()
×
1726
                deleted []*models.ChannelEdgeInfo
×
1727
        )
×
1728
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1729
                for _, chanID := range chanIDs {
×
1730
                        chanIDB := channelIDToBytes(chanID)
×
1731

×
1732
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1733
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1734
                                        Scid:    chanIDB[:],
×
1735
                                        Version: int16(ProtocolV1),
×
1736
                                },
×
1737
                        )
×
1738
                        if errors.Is(err, sql.ErrNoRows) {
×
1739
                                return ErrEdgeNotFound
×
1740
                        } else if err != nil {
×
1741
                                return fmt.Errorf("unable to fetch channel: %w",
×
1742
                                        err)
×
1743
                        }
×
1744

1745
                        node1, node2, err := buildNodeVertices(
×
1746
                                row.Node.PubKey, row.Node_2.PubKey,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        info, err := getAndBuildEdgeInfo(
×
1753
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1754
                                row.Channel, node1, node2,
×
1755
                        )
×
1756
                        if err != nil {
×
1757
                                return err
×
1758
                        }
×
1759

1760
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1761
                        if err != nil {
×
1762
                                return fmt.Errorf("unable to delete "+
×
1763
                                        "channel: %w", err)
×
1764
                        }
×
1765

1766
                        deleted = append(deleted, info)
×
1767

×
1768
                        if !markZombie {
×
1769
                                continue
×
1770
                        }
1771

1772
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1773
                                info.NodeKey2Bytes
×
1774
                        if strictZombiePruning {
×
1775
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1776
                                if row.Policy1LastUpdate.Valid {
×
1777
                                        e1Time := time.Unix(
×
1778
                                                row.Policy1LastUpdate.Int64, 0,
×
1779
                                        )
×
1780
                                        e1UpdateTime = &e1Time
×
1781
                                }
×
1782
                                if row.Policy2LastUpdate.Valid {
×
1783
                                        e2Time := time.Unix(
×
1784
                                                row.Policy2LastUpdate.Int64, 0,
×
1785
                                        )
×
1786
                                        e2UpdateTime = &e2Time
×
1787
                                }
×
1788

1789
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1790
                                        info, e1UpdateTime, e2UpdateTime,
×
1791
                                )
×
1792
                        }
1793

1794
                        err = db.UpsertZombieChannel(
×
1795
                                ctx, sqlc.UpsertZombieChannelParams{
×
1796
                                        Version:  int16(ProtocolV1),
×
1797
                                        Scid:     chanIDB[:],
×
1798
                                        NodeKey1: nodeKey1[:],
×
1799
                                        NodeKey2: nodeKey2[:],
×
1800
                                },
×
1801
                        )
×
1802
                        if err != nil {
×
1803
                                return fmt.Errorf("unable to mark channel as "+
×
1804
                                        "zombie: %w", err)
×
1805
                        }
×
1806
                }
1807

1808
                return nil
×
1809
        }, func() {
×
1810
                deleted = nil
×
1811
        })
×
1812
        if err != nil {
×
1813
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1814
                        err)
×
1815
        }
×
1816

1817
        for _, chanID := range chanIDs {
×
1818
                s.rejectCache.remove(chanID)
×
1819
                s.chanCache.remove(chanID)
×
1820
        }
×
1821

1822
        return deleted, nil
×
1823
}
1824

1825
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1826
// channel identified by the channel ID. If the channel can't be found, then
1827
// ErrEdgeNotFound is returned. A struct which houses the general information
1828
// for the channel itself is returned as well as two structs that contain the
1829
// routing policies for the channel in either direction.
1830
//
1831
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1832
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1833
// the ChannelEdgeInfo will only include the public keys of each node.
1834
//
1835
// NOTE: part of the V1Store interface.
1836
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1837
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1838
        *models.ChannelEdgePolicy, error) {
×
1839

×
1840
        var (
×
1841
                ctx              = context.TODO()
×
1842
                edge             *models.ChannelEdgeInfo
×
1843
                policy1, policy2 *models.ChannelEdgePolicy
×
1844
        )
×
1845
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1846
                var chanIDB [8]byte
×
1847
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1848

×
1849
                row, err := db.GetChannelBySCIDWithPolicies(
×
1850
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1851
                                Scid:    chanIDB[:],
×
1852
                                Version: int16(ProtocolV1),
×
1853
                        },
×
1854
                )
×
1855
                if errors.Is(err, sql.ErrNoRows) {
×
1856
                        // First check if this edge is perhaps in the zombie
×
1857
                        // index.
×
1858
                        isZombie, err := db.IsZombieChannel(
×
1859
                                ctx, sqlc.IsZombieChannelParams{
×
1860
                                        Scid:    chanIDB[:],
×
1861
                                        Version: int16(ProtocolV1),
×
1862
                                },
×
1863
                        )
×
1864
                        if err != nil {
×
1865
                                return fmt.Errorf("unable to check if "+
×
1866
                                        "channel is zombie: %w", err)
×
1867
                        } else if isZombie {
×
1868
                                return ErrZombieEdge
×
1869
                        }
×
1870

1871
                        return ErrEdgeNotFound
×
1872
                } else if err != nil {
×
1873
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1874
                }
×
1875

1876
                node1, node2, err := buildNodeVertices(
×
1877
                        row.Node.PubKey, row.Node_2.PubKey,
×
1878
                )
×
1879
                if err != nil {
×
1880
                        return err
×
1881
                }
×
1882

1883
                edge, err = getAndBuildEdgeInfo(
×
1884
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1885
                        node1, node2,
×
1886
                )
×
1887
                if err != nil {
×
1888
                        return fmt.Errorf("unable to build channel info: %w",
×
1889
                                err)
×
1890
                }
×
1891

1892
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1893
                if err != nil {
×
1894
                        return fmt.Errorf("unable to extract channel "+
×
1895
                                "policies: %w", err)
×
1896
                }
×
1897

1898
                policy1, policy2, err = getAndBuildChanPolicies(
×
1899
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1900
                )
×
1901
                if err != nil {
×
1902
                        return fmt.Errorf("unable to build channel "+
×
1903
                                "policies: %w", err)
×
1904
                }
×
1905

1906
                return nil
×
1907
        }, sqldb.NoOpReset)
1908
        if err != nil {
×
1909
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1910
                        err)
×
1911
        }
×
1912

1913
        return edge, policy1, policy2, nil
×
1914
}
1915

1916
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1917
// the channel identified by the funding outpoint. If the channel can't be
1918
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1919
// information for the channel itself is returned as well as two structs that
1920
// contain the routing policies for the channel in either direction.
1921
//
1922
// NOTE: part of the V1Store interface.
1923
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1924
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1925
        *models.ChannelEdgePolicy, error) {
×
1926

×
1927
        var (
×
1928
                ctx              = context.TODO()
×
1929
                edge             *models.ChannelEdgeInfo
×
1930
                policy1, policy2 *models.ChannelEdgePolicy
×
1931
        )
×
1932
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1933
                row, err := db.GetChannelByOutpointWithPolicies(
×
1934
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1935
                                Outpoint: op.String(),
×
1936
                                Version:  int16(ProtocolV1),
×
1937
                        },
×
1938
                )
×
1939
                if errors.Is(err, sql.ErrNoRows) {
×
1940
                        return ErrEdgeNotFound
×
1941
                } else if err != nil {
×
1942
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1943
                }
×
1944

1945
                node1, node2, err := buildNodeVertices(
×
1946
                        row.Node1Pubkey, row.Node2Pubkey,
×
1947
                )
×
1948
                if err != nil {
×
1949
                        return err
×
1950
                }
×
1951

1952
                edge, err = getAndBuildEdgeInfo(
×
1953
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1954
                        node1, node2,
×
1955
                )
×
1956
                if err != nil {
×
1957
                        return fmt.Errorf("unable to build channel info: %w",
×
1958
                                err)
×
1959
                }
×
1960

1961
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1962
                if err != nil {
×
1963
                        return fmt.Errorf("unable to extract channel "+
×
1964
                                "policies: %w", err)
×
1965
                }
×
1966

1967
                policy1, policy2, err = getAndBuildChanPolicies(
×
1968
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1969
                )
×
1970
                if err != nil {
×
1971
                        return fmt.Errorf("unable to build channel "+
×
1972
                                "policies: %w", err)
×
1973
                }
×
1974

1975
                return nil
×
1976
        }, sqldb.NoOpReset)
1977
        if err != nil {
×
1978
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1979
                        err)
×
1980
        }
×
1981

1982
        return edge, policy1, policy2, nil
×
1983
}
1984

1985
// HasChannelEdge returns true if the database knows of a channel edge with the
1986
// passed channel ID, and false otherwise. If an edge with that ID is found
1987
// within the graph, then two time stamps representing the last time the edge
1988
// was updated for both directed edges are returned along with the boolean. If
1989
// it is not found, then the zombie index is checked and its result is returned
1990
// as the second boolean.
1991
//
1992
// NOTE: part of the V1Store interface.
1993
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1994
        bool, error) {
×
1995

×
1996
        ctx := context.TODO()
×
1997

×
1998
        var (
×
1999
                exists          bool
×
2000
                isZombie        bool
×
2001
                node1LastUpdate time.Time
×
2002
                node2LastUpdate time.Time
×
2003
        )
×
2004

×
2005
        // We'll query the cache with the shared lock held to allow multiple
×
2006
        // readers to access values in the cache concurrently if they exist.
×
2007
        s.cacheMu.RLock()
×
2008
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2009
                s.cacheMu.RUnlock()
×
2010
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2011
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2012
                exists, isZombie = entry.flags.unpack()
×
2013

×
2014
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2015
        }
×
2016
        s.cacheMu.RUnlock()
×
2017

×
2018
        s.cacheMu.Lock()
×
2019
        defer s.cacheMu.Unlock()
×
2020

×
2021
        // The item was not found with the shared lock, so we'll acquire the
×
2022
        // exclusive lock and check the cache again in case another method added
×
2023
        // the entry to the cache while no lock was held.
×
2024
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2025
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2026
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2027
                exists, isZombie = entry.flags.unpack()
×
2028

×
2029
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2030
        }
×
2031

2032
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2033
                var chanIDB [8]byte
×
2034
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2035

×
2036
                channel, err := db.GetChannelBySCID(
×
2037
                        ctx, sqlc.GetChannelBySCIDParams{
×
2038
                                Scid:    chanIDB[:],
×
2039
                                Version: int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        // Check if it is a zombie channel.
×
2044
                        isZombie, err = db.IsZombieChannel(
×
2045
                                ctx, sqlc.IsZombieChannelParams{
×
2046
                                        Scid:    chanIDB[:],
×
2047
                                        Version: int16(ProtocolV1),
×
2048
                                },
×
2049
                        )
×
2050
                        if err != nil {
×
2051
                                return fmt.Errorf("could not check if channel "+
×
2052
                                        "is zombie: %w", err)
×
2053
                        }
×
2054

2055
                        return nil
×
2056
                } else if err != nil {
×
2057
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2058
                }
×
2059

2060
                exists = true
×
2061

×
2062
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2063
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2064
                                Version:   int16(ProtocolV1),
×
2065
                                ChannelID: channel.ID,
×
2066
                                NodeID:    channel.NodeID1,
×
2067
                        },
×
2068
                )
×
2069
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2070
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2071
                                err)
×
2072
                } else if err == nil {
×
2073
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2074
                }
×
2075

2076
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2077
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2078
                                Version:   int16(ProtocolV1),
×
2079
                                ChannelID: channel.ID,
×
2080
                                NodeID:    channel.NodeID2,
×
2081
                        },
×
2082
                )
×
2083
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2084
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2085
                                err)
×
2086
                } else if err == nil {
×
2087
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2088
                }
×
2089

2090
                return nil
×
2091
        }, sqldb.NoOpReset)
2092
        if err != nil {
×
2093
                return time.Time{}, time.Time{}, false, false,
×
2094
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2095
        }
×
2096

2097
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2098
                upd1Time: node1LastUpdate.Unix(),
×
2099
                upd2Time: node2LastUpdate.Unix(),
×
2100
                flags:    packRejectFlags(exists, isZombie),
×
2101
        })
×
2102

×
2103
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2104
}
2105

2106
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2107
// passed channel point (outpoint). If the passed channel doesn't exist within
2108
// the database, then ErrEdgeNotFound is returned.
2109
//
2110
// NOTE: part of the V1Store interface.
2111
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2112
        var (
×
2113
                ctx       = context.TODO()
×
2114
                channelID uint64
×
2115
        )
×
2116
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2117
                chanID, err := db.GetSCIDByOutpoint(
×
2118
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2119
                                Outpoint: chanPoint.String(),
×
2120
                                Version:  int16(ProtocolV1),
×
2121
                        },
×
2122
                )
×
2123
                if errors.Is(err, sql.ErrNoRows) {
×
2124
                        return ErrEdgeNotFound
×
2125
                } else if err != nil {
×
2126
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2127
                                err)
×
2128
                }
×
2129

2130
                channelID = byteOrder.Uint64(chanID)
×
2131

×
2132
                return nil
×
2133
        }, sqldb.NoOpReset)
2134
        if err != nil {
×
2135
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2136
        }
×
2137

2138
        return channelID, nil
×
2139
}
2140

2141
// IsPublicNode is a helper method that determines whether the node with the
2142
// given public key is seen as a public node in the graph from the graph's
2143
// source node's point of view.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2147
        ctx := context.TODO()
×
2148

×
2149
        var isPublic bool
×
2150
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2151
                var err error
×
2152
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2153

×
2154
                return err
×
2155
        }, sqldb.NoOpReset)
×
2156
        if err != nil {
×
2157
                return false, fmt.Errorf("unable to check if node is "+
×
2158
                        "public: %w", err)
×
2159
        }
×
2160

2161
        return isPublic, nil
×
2162
}
2163

2164
// FetchChanInfos returns the set of channel edges that correspond to the passed
2165
// channel ID's. If an edge is the query is unknown to the database, it will
2166
// skipped and the result will contain only those edges that exist at the time
2167
// of the query. This can be used to respond to peer queries that are seeking to
2168
// fill in gaps in their view of the channel graph.
2169
//
2170
// NOTE: part of the V1Store interface.
2171
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2172
        var (
×
2173
                ctx   = context.TODO()
×
2174
                edges []ChannelEdge
×
2175
        )
×
2176
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2177
                for _, chanID := range chanIDs {
×
2178
                        var chanIDB [8]byte
×
2179
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2180

×
2181
                        // TODO(elle): potentially optimize this by using
×
2182
                        //  sqlc.slice() once that works for both SQLite and
×
2183
                        //  Postgres.
×
2184
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2185
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2186
                                        Scid:    chanIDB[:],
×
2187
                                        Version: int16(ProtocolV1),
×
2188
                                },
×
2189
                        )
×
2190
                        if errors.Is(err, sql.ErrNoRows) {
×
2191
                                continue
×
2192
                        } else if err != nil {
×
2193
                                return fmt.Errorf("unable to fetch channel: %w",
×
2194
                                        err)
×
2195
                        }
×
2196

2197
                        node1, node2, err := buildNodes(
×
2198
                                ctx, db, row.Node, row.Node_2,
×
2199
                        )
×
2200
                        if err != nil {
×
2201
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2202
                                        err)
×
2203
                        }
×
2204

2205
                        edge, err := getAndBuildEdgeInfo(
×
2206
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2207
                                row.Channel, node1.PubKeyBytes,
×
2208
                                node2.PubKeyBytes,
×
2209
                        )
×
2210
                        if err != nil {
×
2211
                                return fmt.Errorf("unable to build "+
×
2212
                                        "channel info: %w", err)
×
2213
                        }
×
2214

2215
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2216
                        if err != nil {
×
2217
                                return fmt.Errorf("unable to extract channel "+
×
2218
                                        "policies: %w", err)
×
2219
                        }
×
2220

2221
                        p1, p2, err := getAndBuildChanPolicies(
×
2222
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2223
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2224
                        )
×
2225
                        if err != nil {
×
2226
                                return fmt.Errorf("unable to build channel "+
×
2227
                                        "policies: %w", err)
×
2228
                        }
×
2229

2230
                        edges = append(edges, ChannelEdge{
×
2231
                                Info:    edge,
×
2232
                                Policy1: p1,
×
2233
                                Policy2: p2,
×
2234
                                Node1:   node1,
×
2235
                                Node2:   node2,
×
2236
                        })
×
2237
                }
2238

2239
                return nil
×
2240
        }, func() {
×
2241
                edges = nil
×
2242
        })
×
2243
        if err != nil {
×
2244
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2245
        }
×
2246

2247
        return edges, nil
×
2248
}
2249

2250
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2251
// ID's that we don't know and are not known zombies of the passed set. In other
2252
// words, we perform a set difference of our set of chan ID's and the ones
2253
// passed in. This method can be used by callers to determine the set of
2254
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2255
// known zombies is also returned.
2256
//
2257
// NOTE: part of the V1Store interface.
2258
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2259
        []ChannelUpdateInfo, error) {
×
2260

×
2261
        var (
×
2262
                ctx          = context.TODO()
×
2263
                newChanIDs   []uint64
×
2264
                knownZombies []ChannelUpdateInfo
×
2265
        )
×
2266
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2267
                for _, chanInfo := range chansInfo {
×
2268
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2269
                        var chanIDB [8]byte
×
2270
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2271

×
2272
                        // TODO(elle): potentially optimize this by using
×
2273
                        //  sqlc.slice() once that works for both SQLite and
×
2274
                        //  Postgres.
×
2275
                        _, err := db.GetChannelBySCID(
×
2276
                                ctx, sqlc.GetChannelBySCIDParams{
×
2277
                                        Version: int16(ProtocolV1),
×
2278
                                        Scid:    chanIDB[:],
×
2279
                                },
×
2280
                        )
×
2281
                        if err == nil {
×
2282
                                continue
×
2283
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2284
                                return fmt.Errorf("unable to fetch channel: %w",
×
2285
                                        err)
×
2286
                        }
×
2287

2288
                        isZombie, err := db.IsZombieChannel(
×
2289
                                ctx, sqlc.IsZombieChannelParams{
×
2290
                                        Scid:    chanIDB[:],
×
2291
                                        Version: int16(ProtocolV1),
×
2292
                                },
×
2293
                        )
×
2294
                        if err != nil {
×
2295
                                return fmt.Errorf("unable to fetch zombie "+
×
2296
                                        "channel: %w", err)
×
2297
                        }
×
2298

2299
                        if isZombie {
×
2300
                                knownZombies = append(knownZombies, chanInfo)
×
2301

×
2302
                                continue
×
2303
                        }
2304

2305
                        newChanIDs = append(newChanIDs, channelID)
×
2306
                }
2307

2308
                return nil
×
2309
        }, func() {
×
2310
                newChanIDs = nil
×
2311
                knownZombies = nil
×
2312
        })
×
2313
        if err != nil {
×
2314
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2315
        }
×
2316

2317
        return newChanIDs, knownZombies, nil
×
2318
}
2319

2320
// PruneGraphNodes is a garbage collection method which attempts to prune out
2321
// any nodes from the channel graph that are currently unconnected. This ensure
2322
// that we only maintain a graph of reachable nodes. In the event that a pruned
2323
// node gains more channels, it will be re-added back to the graph.
2324
//
2325
// NOTE: this prunes nodes across protocol versions. It will never prune the
2326
// source nodes.
2327
//
2328
// NOTE: part of the V1Store interface.
2329
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2330
        var ctx = context.TODO()
×
2331

×
2332
        var prunedNodes []route.Vertex
×
2333
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2334
                var err error
×
2335
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2336

×
2337
                return err
×
2338
        }, func() {
×
2339
                prunedNodes = nil
×
2340
        })
×
2341
        if err != nil {
×
2342
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2343
        }
×
2344

2345
        return prunedNodes, nil
×
2346
}
2347

2348
// PruneGraph prunes newly closed channels from the channel graph in response
2349
// to a new block being solved on the network. Any transactions which spend the
2350
// funding output of any known channels within he graph will be deleted.
2351
// Additionally, the "prune tip", or the last block which has been used to
2352
// prune the graph is stored so callers can ensure the graph is fully in sync
2353
// with the current UTXO state. A slice of channels that have been closed by
2354
// the target block along with any pruned nodes are returned if the function
2355
// succeeds without error.
2356
//
2357
// NOTE: part of the V1Store interface.
2358
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2359
        blockHash *chainhash.Hash, blockHeight uint32) (
2360
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2361

×
2362
        ctx := context.TODO()
×
2363

×
2364
        s.cacheMu.Lock()
×
2365
        defer s.cacheMu.Unlock()
×
2366

×
2367
        var (
×
2368
                closedChans []*models.ChannelEdgeInfo
×
2369
                prunedNodes []route.Vertex
×
2370
        )
×
2371
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2372
                for _, outpoint := range spentOutputs {
×
2373
                        // TODO(elle): potentially optimize this by using
×
2374
                        //  sqlc.slice() once that works for both SQLite and
×
2375
                        //  Postgres.
×
2376
                        //
×
2377
                        // NOTE: this fetches channels for all protocol
×
2378
                        // versions.
×
2379
                        row, err := db.GetChannelByOutpoint(
×
2380
                                ctx, outpoint.String(),
×
2381
                        )
×
2382
                        if errors.Is(err, sql.ErrNoRows) {
×
2383
                                continue
×
2384
                        } else if err != nil {
×
2385
                                return fmt.Errorf("unable to fetch channel: %w",
×
2386
                                        err)
×
2387
                        }
×
2388

2389
                        node1, node2, err := buildNodeVertices(
×
2390
                                row.Node1Pubkey, row.Node2Pubkey,
×
2391
                        )
×
2392
                        if err != nil {
×
2393
                                return err
×
2394
                        }
×
2395

2396
                        info, err := getAndBuildEdgeInfo(
×
2397
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2398
                                row.Channel, node1, node2,
×
2399
                        )
×
2400
                        if err != nil {
×
2401
                                return err
×
2402
                        }
×
2403

2404
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2405
                        if err != nil {
×
2406
                                return fmt.Errorf("unable to delete "+
×
2407
                                        "channel: %w", err)
×
2408
                        }
×
2409

2410
                        closedChans = append(closedChans, info)
×
2411
                }
2412

2413
                err := db.UpsertPruneLogEntry(
×
2414
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2415
                                BlockHash:   blockHash[:],
×
2416
                                BlockHeight: int64(blockHeight),
×
2417
                        },
×
2418
                )
×
2419
                if err != nil {
×
2420
                        return fmt.Errorf("unable to insert prune log "+
×
2421
                                "entry: %w", err)
×
2422
                }
×
2423

2424
                // Now that we've pruned some channels, we'll also prune any
2425
                // nodes that no longer have any channels.
2426
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2427
                if err != nil {
×
2428
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2429
                                err)
×
2430
                }
×
2431

2432
                return nil
×
2433
        }, func() {
×
2434
                prunedNodes = nil
×
2435
                closedChans = nil
×
2436
        })
×
2437
        if err != nil {
×
2438
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2439
        }
×
2440

2441
        for _, channel := range closedChans {
×
2442
                s.rejectCache.remove(channel.ChannelID)
×
2443
                s.chanCache.remove(channel.ChannelID)
×
2444
        }
×
2445

2446
        return closedChans, prunedNodes, nil
×
2447
}
2448

2449
// ChannelView returns the verifiable edge information for each active channel
2450
// within the known channel graph. The set of UTXOs (along with their scripts)
2451
// returned are the ones that need to be watched on chain to detect channel
2452
// closes on the resident blockchain.
2453
//
2454
// NOTE: part of the V1Store interface.
2455
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2456
        var (
×
2457
                ctx        = context.TODO()
×
2458
                edgePoints []EdgePoint
×
2459
        )
×
2460

×
2461
        handleChannel := func(db SQLQueries,
×
2462
                channel sqlc.ListChannelsPaginatedRow) error {
×
2463

×
2464
                pkScript, err := genMultiSigP2WSH(
×
2465
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2466
                )
×
2467
                if err != nil {
×
2468
                        return err
×
2469
                }
×
2470

2471
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2472
                if err != nil {
×
2473
                        return err
×
2474
                }
×
2475

2476
                edgePoints = append(edgePoints, EdgePoint{
×
2477
                        FundingPkScript: pkScript,
×
2478
                        OutPoint:        *op,
×
2479
                })
×
2480

×
2481
                return nil
×
2482
        }
2483

2484
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2485
                lastID := int64(-1)
×
2486
                for {
×
2487
                        rows, err := db.ListChannelsPaginated(
×
2488
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2489
                                        Version: int16(ProtocolV1),
×
2490
                                        ID:      lastID,
×
2491
                                        Limit:   pageSize,
×
2492
                                },
×
2493
                        )
×
2494
                        if err != nil {
×
2495
                                return err
×
2496
                        }
×
2497

2498
                        if len(rows) == 0 {
×
2499
                                break
×
2500
                        }
2501

2502
                        for _, row := range rows {
×
2503
                                err := handleChannel(db, row)
×
2504
                                if err != nil {
×
2505
                                        return err
×
2506
                                }
×
2507

2508
                                lastID = row.ID
×
2509
                        }
2510
                }
2511

2512
                return nil
×
2513
        }, func() {
×
2514
                edgePoints = nil
×
2515
        })
×
2516
        if err != nil {
×
2517
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2518
        }
×
2519

2520
        return edgePoints, nil
×
2521
}
2522

2523
// PruneTip returns the block height and hash of the latest block that has been
2524
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2525
// to tell if the graph is currently in sync with the current best known UTXO
2526
// state.
2527
//
2528
// NOTE: part of the V1Store interface.
2529
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2530
        var (
×
2531
                ctx       = context.TODO()
×
2532
                tipHash   chainhash.Hash
×
2533
                tipHeight uint32
×
2534
        )
×
2535
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2536
                pruneTip, err := db.GetPruneTip(ctx)
×
2537
                if errors.Is(err, sql.ErrNoRows) {
×
2538
                        return ErrGraphNeverPruned
×
2539
                } else if err != nil {
×
2540
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2541
                }
×
2542

2543
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2544
                tipHeight = uint32(pruneTip.BlockHeight)
×
2545

×
2546
                return nil
×
2547
        }, sqldb.NoOpReset)
2548
        if err != nil {
×
2549
                return nil, 0, err
×
2550
        }
×
2551

2552
        return &tipHash, tipHeight, nil
×
2553
}
2554

2555
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2556
//
2557
// NOTE: this prunes nodes across protocol versions. It will never prune the
2558
// source nodes.
2559
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2560
        db SQLQueries) ([]route.Vertex, error) {
×
2561

×
2562
        // Fetch all un-connected nodes from the database.
×
2563
        // NOTE: this will not include any nodes that are listed in the
×
2564
        // source table.
×
2565
        nodes, err := db.GetUnconnectedNodes(ctx)
×
2566
        if err != nil {
×
2567
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
2568
                        err)
×
2569
        }
×
2570

2571
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
2572
        for _, node := range nodes {
×
2573
                // TODO(elle): update to use sqlc.slice() once that works.
×
2574
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
2575
                        return nil, fmt.Errorf("unable to delete "+
×
2576
                                "node(id=%d): %w", node.ID, err)
×
2577
                }
×
2578

2579
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
2580
                if err != nil {
×
2581
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2582
                                "for node(id=%d): %w", node.ID, err)
×
2583
                }
×
2584

2585
                prunedNodes = append(prunedNodes, pubKey)
×
2586
        }
2587

2588
        return prunedNodes, nil
×
2589
}
2590

2591
// DisconnectBlockAtHeight is used to indicate that the block specified
2592
// by the passed height has been disconnected from the main chain. This
2593
// will "rewind" the graph back to the height below, deleting channels
2594
// that are no longer confirmed from the graph. The prune log will be
2595
// set to the last prune height valid for the remaining chain.
2596
// Channels that were removed from the graph resulting from the
2597
// disconnected block are returned.
2598
//
2599
// NOTE: part of the V1Store interface.
2600
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2601
        []*models.ChannelEdgeInfo, error) {
×
2602

×
2603
        ctx := context.TODO()
×
2604

×
2605
        var (
×
2606
                // Every channel having a ShortChannelID starting at 'height'
×
2607
                // will no longer be confirmed.
×
2608
                startShortChanID = lnwire.ShortChannelID{
×
2609
                        BlockHeight: height,
×
2610
                }
×
2611

×
2612
                // Delete everything after this height from the db up until the
×
2613
                // SCID alias range.
×
2614
                endShortChanID = aliasmgr.StartingAlias
×
2615

×
2616
                removedChans []*models.ChannelEdgeInfo
×
2617
        )
×
2618

×
2619
        var chanIDStart [8]byte
×
2620
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2621
        var chanIDEnd [8]byte
×
2622
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2623

×
2624
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2625
                rows, err := db.GetChannelsBySCIDRange(
×
2626
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2627
                                StartScid: chanIDStart[:],
×
2628
                                EndScid:   chanIDEnd[:],
×
2629
                        },
×
2630
                )
×
2631
                if err != nil {
×
2632
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2633
                }
×
2634

2635
                for _, row := range rows {
×
2636
                        node1, node2, err := buildNodeVertices(
×
2637
                                row.Node1PubKey, row.Node2PubKey,
×
2638
                        )
×
2639
                        if err != nil {
×
2640
                                return err
×
2641
                        }
×
2642

2643
                        channel, err := getAndBuildEdgeInfo(
×
2644
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2645
                                row.Channel, node1, node2,
×
2646
                        )
×
2647
                        if err != nil {
×
2648
                                return err
×
2649
                        }
×
2650

2651
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2652
                        if err != nil {
×
2653
                                return fmt.Errorf("unable to delete "+
×
2654
                                        "channel: %w", err)
×
2655
                        }
×
2656

2657
                        removedChans = append(removedChans, channel)
×
2658
                }
2659

2660
                return db.DeletePruneLogEntriesInRange(
×
2661
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2662
                                StartHeight: int64(height),
×
2663
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2664
                        },
×
2665
                )
×
2666
        }, func() {
×
2667
                removedChans = nil
×
2668
        })
×
2669
        if err != nil {
×
2670
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2671
                        "height: %w", err)
×
2672
        }
×
2673

2674
        for _, channel := range removedChans {
×
2675
                s.rejectCache.remove(channel.ChannelID)
×
2676
                s.chanCache.remove(channel.ChannelID)
×
2677
        }
×
2678

2679
        return removedChans, nil
×
2680
}
2681

2682
// AddEdgeProof sets the proof of an existing edge in the graph database.
2683
//
2684
// NOTE: part of the V1Store interface.
2685
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
NEW
2686
        proof *models.ChannelAuthProof) error {
×
NEW
2687

×
NEW
2688
        var (
×
NEW
2689
                ctx       = context.TODO()
×
NEW
2690
                scidBytes = channelIDToBytes(scid.ToUint64())
×
NEW
2691
        )
×
NEW
2692

×
NEW
2693
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2694
                res, err := db.AddV1ChannelProof(
×
NEW
2695
                        ctx, sqlc.AddV1ChannelProofParams{
×
NEW
2696
                                Scid:              scidBytes[:],
×
NEW
2697
                                Node1Signature:    proof.NodeSig1Bytes,
×
NEW
2698
                                Node2Signature:    proof.NodeSig2Bytes,
×
NEW
2699
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
NEW
2700
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
NEW
2701
                        },
×
NEW
2702
                )
×
NEW
2703
                if err != nil {
×
NEW
2704
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
NEW
2705
                }
×
2706

NEW
2707
                n, err := res.RowsAffected()
×
NEW
2708
                if err != nil {
×
NEW
2709
                        return err
×
NEW
2710
                }
×
2711

NEW
2712
                if n == 0 {
×
NEW
2713
                        return fmt.Errorf("no rows affected when adding edge "+
×
NEW
2714
                                "proof for SCID %v", scid)
×
NEW
2715
                } else if n > 1 {
×
NEW
2716
                        return fmt.Errorf("multiple rows affected when adding "+
×
NEW
2717
                                "edge proof for SCID %v: %d rows affected",
×
NEW
2718
                                scid, n)
×
NEW
2719
                }
×
2720

NEW
2721
                return nil
×
2722
        }, sqldb.NoOpReset)
NEW
2723
        if err != nil {
×
NEW
2724
                return fmt.Errorf("unable to add edge proof: %w", err)
×
NEW
2725
        }
×
2726

NEW
2727
        return nil
×
2728
}
2729

2730
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2731
// that we can ignore channel announcements that we know to be closed without
2732
// having to validate them and fetch a block.
2733
//
2734
// NOTE: part of the V1Store interface.
NEW
2735
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
NEW
2736
        var (
×
NEW
2737
                ctx     = context.TODO()
×
NEW
2738
                chanIDB = channelIDToBytes(scid.ToUint64())
×
NEW
2739
        )
×
NEW
2740

×
NEW
2741
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2742
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
NEW
2743
        }, sqldb.NoOpReset)
×
2744
}
2745

2746
// IsClosedScid checks whether a channel identified by the passed in scid is
2747
// closed. This helps avoid having to perform expensive validation checks.
2748
//
2749
// NOTE: part of the V1Store interface.
NEW
2750
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
NEW
2751
        var (
×
NEW
2752
                ctx      = context.TODO()
×
NEW
2753
                isClosed bool
×
NEW
2754
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
NEW
2755
        )
×
NEW
2756
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2757
                var err error
×
NEW
2758
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
NEW
2759
                if err != nil {
×
NEW
2760
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2761
                                err)
×
NEW
2762
                }
×
2763

NEW
2764
                return nil
×
2765
        }, sqldb.NoOpReset)
NEW
2766
        if err != nil {
×
NEW
2767
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2768
                        err)
×
NEW
2769
        }
×
2770

NEW
2771
        return isClosed, nil
×
2772
}
2773

2774
// GraphSession will provide the call-back with access to a NodeTraverser
2775
// instance which can be used to perform queries against the channel graph.
2776
//
2777
// NOTE: part of the V1Store interface.
NEW
2778
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
NEW
2779
        var ctx = context.TODO()
×
NEW
2780

×
NEW
2781
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2782
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
NEW
2783
        }, sqldb.NoOpReset)
×
2784
}
2785

2786
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2787
// read only transaction for a consistent view of the graph.
2788
type sqlNodeTraverser struct {
2789
        db    SQLQueries
2790
        chain chainhash.Hash
2791
}
2792

2793
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2794
// NodeTraverser interface.
2795
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2796

2797
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2798
func newSQLNodeTraverser(db SQLQueries,
NEW
2799
        chain chainhash.Hash) *sqlNodeTraverser {
×
NEW
2800

×
NEW
2801
        return &sqlNodeTraverser{
×
NEW
2802
                db:    db,
×
NEW
2803
                chain: chain,
×
NEW
2804
        }
×
NEW
2805
}
×
2806

2807
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2808
// node.
2809
//
2810
// NOTE: Part of the NodeTraverser interface.
2811
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
2812
        cb func(channel *DirectedChannel) error) error {
×
NEW
2813

×
NEW
2814
        ctx := context.TODO()
×
NEW
2815

×
NEW
2816
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
NEW
2817
}
×
2818

2819
// FetchNodeFeatures returns the features of the given node. If the node is
2820
// unknown, assume no additional features are supported.
2821
//
2822
// NOTE: Part of the NodeTraverser interface.
2823
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
NEW
2824
        *lnwire.FeatureVector, error) {
×
NEW
2825

×
NEW
2826
        ctx := context.TODO()
×
NEW
2827

×
NEW
2828
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
NEW
2829
}
×
2830

2831
// forEachNodeDirectedChannel iterates through all channels of a given
2832
// node, executing the passed callback on the directed edge representing the
2833
// channel and its incoming policy. If the node is not found, no error is
2834
// returned.
2835
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2836
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2837

×
2838
        toNodeCallback := func() route.Vertex {
×
2839
                return nodePub
×
2840
        }
×
2841

2842
        dbID, err := db.GetNodeIDByPubKey(
×
2843
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2844
                        Version: int16(ProtocolV1),
×
2845
                        PubKey:  nodePub[:],
×
2846
                },
×
2847
        )
×
2848
        if errors.Is(err, sql.ErrNoRows) {
×
2849
                return nil
×
2850
        } else if err != nil {
×
2851
                return fmt.Errorf("unable to fetch node: %w", err)
×
2852
        }
×
2853

2854
        rows, err := db.ListChannelsByNodeID(
×
2855
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2856
                        Version: int16(ProtocolV1),
×
2857
                        NodeID1: dbID,
×
2858
                },
×
2859
        )
×
2860
        if err != nil {
×
2861
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2862
        }
×
2863

2864
        // Exit early if there are no channels for this node so we don't
2865
        // do the unnecessary feature fetching.
2866
        if len(rows) == 0 {
×
2867
                return nil
×
2868
        }
×
2869

2870
        features, err := getNodeFeatures(ctx, db, dbID)
×
2871
        if err != nil {
×
2872
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2873
        }
×
2874

2875
        for _, row := range rows {
×
2876
                node1, node2, err := buildNodeVertices(
×
2877
                        row.Node1Pubkey, row.Node2Pubkey,
×
2878
                )
×
2879
                if err != nil {
×
2880
                        return fmt.Errorf("unable to build node vertices: %w",
×
2881
                                err)
×
2882
                }
×
2883

2884
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2885

×
2886
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2887
                if err != nil {
×
2888
                        return err
×
2889
                }
×
2890

2891
                var p1, p2 *models.CachedEdgePolicy
×
2892
                if dbPol1 != nil {
×
2893
                        policy1, err := buildChanPolicy(
×
2894
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2895
                        )
×
2896
                        if err != nil {
×
2897
                                return err
×
2898
                        }
×
2899

2900
                        p1 = models.NewCachedPolicy(policy1)
×
2901
                }
2902
                if dbPol2 != nil {
×
2903
                        policy2, err := buildChanPolicy(
×
2904
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2905
                        )
×
2906
                        if err != nil {
×
2907
                                return err
×
2908
                        }
×
2909

2910
                        p2 = models.NewCachedPolicy(policy2)
×
2911
                }
2912

2913
                // Determine the outgoing and incoming policy for this
2914
                // channel and node combo.
2915
                outPolicy, inPolicy := p1, p2
×
2916
                if p1 != nil && node2 == nodePub {
×
2917
                        outPolicy, inPolicy = p2, p1
×
2918
                } else if p2 != nil && node1 != nodePub {
×
2919
                        outPolicy, inPolicy = p2, p1
×
2920
                }
×
2921

2922
                var cachedInPolicy *models.CachedEdgePolicy
×
2923
                if inPolicy != nil {
×
2924
                        cachedInPolicy = inPolicy
×
2925
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2926
                        cachedInPolicy.ToNodeFeatures = features
×
2927
                }
×
2928

2929
                directedChannel := &DirectedChannel{
×
2930
                        ChannelID:    edge.ChannelID,
×
2931
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2932
                        OtherNode:    edge.NodeKey2Bytes,
×
2933
                        Capacity:     edge.Capacity,
×
2934
                        OutPolicySet: outPolicy != nil,
×
2935
                        InPolicy:     cachedInPolicy,
×
2936
                }
×
2937
                if outPolicy != nil {
×
2938
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2939
                                directedChannel.InboundFee = fee
×
2940
                        })
×
2941
                }
2942

2943
                if nodePub == edge.NodeKey2Bytes {
×
2944
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2945
                }
×
2946

2947
                if err := cb(directedChannel); err != nil {
×
2948
                        return err
×
2949
                }
×
2950
        }
2951

2952
        return nil
×
2953
}
2954

2955
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2956
// and executes the provided callback for each node.
2957
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2958
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2959

×
NEW
2960
        lastID := int64(-1)
×
2961

×
2962
        for {
×
2963
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2964
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2965
                                Version: int16(ProtocolV1),
×
2966
                                ID:      lastID,
×
2967
                                Limit:   pageSize,
×
2968
                        },
×
2969
                )
×
2970
                if err != nil {
×
2971
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2972
                }
×
2973

2974
                if len(nodes) == 0 {
×
2975
                        break
×
2976
                }
2977

2978
                for _, node := range nodes {
×
2979
                        var pub route.Vertex
×
2980
                        copy(pub[:], node.PubKey)
×
2981

×
2982
                        if err := cb(node.ID, pub); err != nil {
×
2983
                                return fmt.Errorf("forEachNodeCacheable "+
×
2984
                                        "callback failed for node(id=%d): %w",
×
2985
                                        node.ID, err)
×
2986
                        }
×
2987

2988
                        lastID = node.ID
×
2989
                }
2990
        }
2991

2992
        return nil
×
2993
}
2994

2995
// forEachNodeChannel iterates through all channels of a node, executing
2996
// the passed callback on each. The call-back is provided with the channel's
2997
// edge information, the outgoing policy and the incoming policy for the
2998
// channel and node combo.
2999
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3000
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3001
                *models.ChannelEdgePolicy,
3002
                *models.ChannelEdgePolicy) error) error {
×
3003

×
3004
        // Get all the V1 channels for this node.Add commentMore actions
×
3005
        rows, err := db.ListChannelsByNodeID(
×
3006
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3007
                        Version: int16(ProtocolV1),
×
3008
                        NodeID1: id,
×
3009
                },
×
3010
        )
×
3011
        if err != nil {
×
3012
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3013
        }
×
3014

3015
        // Call the call-back for each channel and its known policies.
3016
        for _, row := range rows {
×
3017
                node1, node2, err := buildNodeVertices(
×
3018
                        row.Node1Pubkey, row.Node2Pubkey,
×
3019
                )
×
3020
                if err != nil {
×
3021
                        return fmt.Errorf("unable to build node vertices: %w",
×
3022
                                err)
×
3023
                }
×
3024

3025
                edge, err := getAndBuildEdgeInfo(
×
3026
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3027
                        node2,
×
3028
                )
×
3029
                if err != nil {
×
3030
                        return fmt.Errorf("unable to build channel info: %w",
×
3031
                                err)
×
3032
                }
×
3033

3034
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3035
                if err != nil {
×
3036
                        return fmt.Errorf("unable to extract channel "+
×
3037
                                "policies: %w", err)
×
3038
                }
×
3039

3040
                p1, p2, err := getAndBuildChanPolicies(
×
3041
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3042
                )
×
3043
                if err != nil {
×
3044
                        return fmt.Errorf("unable to build channel "+
×
3045
                                "policies: %w", err)
×
3046
                }
×
3047

3048
                // Determine the outgoing and incoming policy for this
3049
                // channel and node combo.
3050
                p1ToNode := row.Channel.NodeID2
×
3051
                p2ToNode := row.Channel.NodeID1
×
3052
                outPolicy, inPolicy := p1, p2
×
3053
                if (p1 != nil && p1ToNode == id) ||
×
3054
                        (p2 != nil && p2ToNode != id) {
×
3055

×
3056
                        outPolicy, inPolicy = p2, p1
×
3057
                }
×
3058

3059
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3060
                        return err
×
3061
                }
×
3062
        }
3063

3064
        return nil
×
3065
}
3066

3067
// updateChanEdgePolicy upserts the channel policy info we have stored for
3068
// a channel we already know of.
3069
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3070
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3071
        error) {
×
3072

×
3073
        var (
×
3074
                node1Pub, node2Pub route.Vertex
×
3075
                isNode1            bool
×
3076
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3077
        )
×
3078

×
3079
        // Check that this edge policy refers to a channel that we already
×
3080
        // know of. We do this explicitly so that we can return the appropriate
×
3081
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3082
        // abort the transaction which would abort the entire batch.
×
3083
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3084
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3085
                        Scid:    chanIDB[:],
×
3086
                        Version: int16(ProtocolV1),
×
3087
                },
×
3088
        )
×
3089
        if errors.Is(err, sql.ErrNoRows) {
×
3090
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3091
        } else if err != nil {
×
3092
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3093
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3094
        }
×
3095

3096
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3097
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3098

×
3099
        // Figure out which node this edge is from.
×
3100
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3101
        nodeID := dbChan.NodeID1
×
3102
        if !isNode1 {
×
3103
                nodeID = dbChan.NodeID2
×
3104
        }
×
3105

3106
        var (
×
3107
                inboundBase sql.NullInt64
×
3108
                inboundRate sql.NullInt64
×
3109
        )
×
3110
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3111
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3112
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3113
        })
×
3114

3115
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3116
                Version:     int16(ProtocolV1),
×
3117
                ChannelID:   dbChan.ID,
×
3118
                NodeID:      nodeID,
×
3119
                Timelock:    int32(edge.TimeLockDelta),
×
3120
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3121
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3122
                MinHtlcMsat: int64(edge.MinHTLC),
×
3123
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3124
                Disabled: sql.NullBool{
×
3125
                        Valid: true,
×
3126
                        Bool:  edge.IsDisabled(),
×
3127
                },
×
3128
                MaxHtlcMsat: sql.NullInt64{
×
3129
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3130
                        Int64: int64(edge.MaxHTLC),
×
3131
                },
×
3132
                InboundBaseFeeMsat:      inboundBase,
×
3133
                InboundFeeRateMilliMsat: inboundRate,
×
3134
                Signature:               edge.SigBytes,
×
3135
        })
×
3136
        if err != nil {
×
3137
                return node1Pub, node2Pub, isNode1,
×
3138
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3139
        }
×
3140

3141
        // Convert the flat extra opaque data into a map of TLV types to
3142
        // values.
3143
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3144
        if err != nil {
×
3145
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3146
                        "marshal extra opaque data: %w", err)
×
3147
        }
×
3148

3149
        // Update the channel policy's extra signed fields.
3150
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3151
        if err != nil {
×
3152
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3153
                        "policy extra TLVs: %w", err)
×
3154
        }
×
3155

3156
        return node1Pub, node2Pub, isNode1, nil
×
3157
}
3158

3159
// getNodeByPubKey attempts to look up a target node by its public key.
3160
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3161
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3162

×
3163
        dbNode, err := db.GetNodeByPubKey(
×
3164
                ctx, sqlc.GetNodeByPubKeyParams{
×
3165
                        Version: int16(ProtocolV1),
×
3166
                        PubKey:  pubKey[:],
×
3167
                },
×
3168
        )
×
3169
        if errors.Is(err, sql.ErrNoRows) {
×
3170
                return 0, nil, ErrGraphNodeNotFound
×
3171
        } else if err != nil {
×
3172
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3173
        }
×
3174

3175
        node, err := buildNode(ctx, db, &dbNode)
×
3176
        if err != nil {
×
3177
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3178
        }
×
3179

3180
        return dbNode.ID, node, nil
×
3181
}
3182

3183
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3184
// provided database channel row and the public keys of the two nodes
3185
// involved in the channel.
3186
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3187
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3188

×
3189
        return &models.CachedEdgeInfo{
×
3190
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3191
                NodeKey1Bytes: node1Pub,
×
3192
                NodeKey2Bytes: node2Pub,
×
3193
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3194
        }
×
3195
}
×
3196

3197
// buildNode constructs a LightningNode instance from the given database node
3198
// record. The node's features, addresses and extra signed fields are also
3199
// fetched from the database and set on the node.
3200
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3201
        *models.LightningNode, error) {
×
3202

×
3203
        if dbNode.Version != int16(ProtocolV1) {
×
3204
                return nil, fmt.Errorf("unsupported node version: %d",
×
3205
                        dbNode.Version)
×
3206
        }
×
3207

3208
        var pub [33]byte
×
3209
        copy(pub[:], dbNode.PubKey)
×
3210

×
3211
        node := &models.LightningNode{
×
3212
                PubKeyBytes: pub,
×
3213
                Features:    lnwire.EmptyFeatureVector(),
×
3214
                LastUpdate:  time.Unix(0, 0),
×
3215
        }
×
3216

×
3217
        if len(dbNode.Signature) == 0 {
×
3218
                return node, nil
×
3219
        }
×
3220

3221
        node.HaveNodeAnnouncement = true
×
3222
        node.AuthSigBytes = dbNode.Signature
×
3223
        node.Alias = dbNode.Alias.String
×
3224
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3225

×
3226
        var err error
×
3227
        if dbNode.Color.Valid {
×
3228
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3229
                if err != nil {
×
3230
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3231
                                err)
×
3232
                }
×
3233
        }
3234

3235
        // Fetch the node's features.
3236
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3237
        if err != nil {
×
3238
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3239
                        "features: %w", dbNode.ID, err)
×
3240
        }
×
3241

3242
        // Fetch the node's addresses.
3243
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3244
        if err != nil {
×
3245
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3246
                        "addresses: %w", dbNode.ID, err)
×
3247
        }
×
3248

3249
        // Fetch the node's extra signed fields.
3250
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3251
        if err != nil {
×
3252
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3253
                        "extra signed fields: %w", dbNode.ID, err)
×
3254
        }
×
3255

3256
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3257
        if err != nil {
×
3258
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3259
                        "fields: %w", err)
×
3260
        }
×
3261

3262
        if len(recs) != 0 {
×
3263
                node.ExtraOpaqueData = recs
×
3264
        }
×
3265

3266
        return node, nil
×
3267
}
3268

3269
// getNodeFeatures fetches the feature bits and constructs the feature vector
3270
// for a node with the given DB ID.
3271
func getNodeFeatures(ctx context.Context, db SQLQueries,
3272
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3273

×
3274
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3275
        if err != nil {
×
3276
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3277
                        nodeID, err)
×
3278
        }
×
3279

3280
        features := lnwire.EmptyFeatureVector()
×
3281
        for _, feature := range rows {
×
3282
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3283
        }
×
3284

3285
        return features, nil
×
3286
}
3287

3288
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3289
// given DB ID.
3290
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3291
        nodeID int64) (map[uint64][]byte, error) {
×
3292

×
3293
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3294
        if err != nil {
×
3295
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3296
                        "signed fields: %w", nodeID, err)
×
3297
        }
×
3298

3299
        extraFields := make(map[uint64][]byte)
×
3300
        for _, field := range fields {
×
3301
                extraFields[uint64(field.Type)] = field.Value
×
3302
        }
×
3303

3304
        return extraFields, nil
×
3305
}
3306

3307
// upsertNode upserts the node record into the database. If the node already
3308
// exists, then the node's information is updated. If the node doesn't exist,
3309
// then a new node is created. The node's features, addresses and extra TLV
3310
// types are also updated. The node's DB ID is returned.
3311
func upsertNode(ctx context.Context, db SQLQueries,
3312
        node *models.LightningNode) (int64, error) {
×
3313

×
3314
        params := sqlc.UpsertNodeParams{
×
3315
                Version: int16(ProtocolV1),
×
3316
                PubKey:  node.PubKeyBytes[:],
×
3317
        }
×
3318

×
3319
        if node.HaveNodeAnnouncement {
×
3320
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3321
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3322
                params.Alias = sqldb.SQLStr(node.Alias)
×
3323
                params.Signature = node.AuthSigBytes
×
3324
        }
×
3325

3326
        nodeID, err := db.UpsertNode(ctx, params)
×
3327
        if err != nil {
×
3328
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3329
                        err)
×
3330
        }
×
3331

3332
        // We can exit here if we don't have the announcement yet.
3333
        if !node.HaveNodeAnnouncement {
×
3334
                return nodeID, nil
×
3335
        }
×
3336

3337
        // Update the node's features.
3338
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3339
        if err != nil {
×
3340
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3341
        }
×
3342

3343
        // Update the node's addresses.
3344
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3345
        if err != nil {
×
3346
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3347
        }
×
3348

3349
        // Convert the flat extra opaque data into a map of TLV types to
3350
        // values.
3351
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3352
        if err != nil {
×
3353
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3354
                        err)
×
3355
        }
×
3356

3357
        // Update the node's extra signed fields.
3358
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3359
        if err != nil {
×
3360
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3361
        }
×
3362

3363
        return nodeID, nil
×
3364
}
3365

3366
// upsertNodeFeatures updates the node's features node_features table. This
3367
// includes deleting any feature bits no longer present and inserting any new
3368
// feature bits. If the feature bit does not yet exist in the features table,
3369
// then an entry is created in that table first.
3370
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3371
        features *lnwire.FeatureVector) error {
×
3372

×
3373
        // Get any existing features for the node.
×
3374
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3375
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3376
                return err
×
3377
        }
×
3378

3379
        // Copy the nodes latest set of feature bits.
3380
        newFeatures := make(map[int32]struct{})
×
3381
        if features != nil {
×
3382
                for feature := range features.Features() {
×
3383
                        newFeatures[int32(feature)] = struct{}{}
×
3384
                }
×
3385
        }
3386

3387
        // For any current feature that already exists in the DB, remove it from
3388
        // the in-memory map. For any existing feature that does not exist in
3389
        // the in-memory map, delete it from the database.
3390
        for _, feature := range existingFeatures {
×
3391
                // The feature is still present, so there are no updates to be
×
3392
                // made.
×
3393
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3394
                        delete(newFeatures, feature.FeatureBit)
×
3395
                        continue
×
3396
                }
3397

3398
                // The feature is no longer present, so we remove it from the
3399
                // database.
3400
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3401
                        NodeID:     nodeID,
×
3402
                        FeatureBit: feature.FeatureBit,
×
3403
                })
×
3404
                if err != nil {
×
3405
                        return fmt.Errorf("unable to delete node(%d) "+
×
3406
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3407
                                err)
×
3408
                }
×
3409
        }
3410

3411
        // Any remaining entries in newFeatures are new features that need to be
3412
        // added to the database for the first time.
3413
        for feature := range newFeatures {
×
3414
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3415
                        NodeID:     nodeID,
×
3416
                        FeatureBit: feature,
×
3417
                })
×
3418
                if err != nil {
×
3419
                        return fmt.Errorf("unable to insert node(%d) "+
×
3420
                                "feature(%v): %w", nodeID, feature, err)
×
3421
                }
×
3422
        }
3423

3424
        return nil
×
3425
}
3426

3427
// fetchNodeFeatures fetches the features for a node with the given public key.
3428
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3429
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3430

×
3431
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3432
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3433
                        PubKey:  nodePub[:],
×
3434
                        Version: int16(ProtocolV1),
×
3435
                },
×
3436
        )
×
3437
        if err != nil {
×
3438
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3439
                        nodePub, err)
×
3440
        }
×
3441

3442
        features := lnwire.EmptyFeatureVector()
×
3443
        for _, bit := range rows {
×
3444
                features.Set(lnwire.FeatureBit(bit))
×
3445
        }
×
3446

3447
        return features, nil
×
3448
}
3449

3450
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type recorded next to each address string determines how that
// string is serialised on write and parsed back into a net.Addr on read.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose IP is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose IP is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a version 2 Tor onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a version 3 Tor onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque address payload, persisted as a
	// hex string.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3462

3463
// upsertNodeAddresses updates the node's addresses in the database. This
3464
// includes deleting any existing addresses and inserting the new set of
3465
// addresses. The deletion is necessary since the ordering of the addresses may
3466
// change, and we need to ensure that the database reflects the latest set of
3467
// addresses so that at the time of reconstructing the node announcement, the
3468
// order is preserved and the signature over the message remains valid.
3469
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3470
        addresses []net.Addr) error {
×
3471

×
3472
        // Delete any existing addresses for the node. This is required since
×
3473
        // even if the new set of addresses is the same, the ordering may have
×
3474
        // changed for a given address type.
×
3475
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3476
        if err != nil {
×
3477
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3478
                        nodeID, err)
×
3479
        }
×
3480

3481
        // Copy the nodes latest set of addresses.
3482
        newAddresses := map[dbAddressType][]string{
×
3483
                addressTypeIPv4:   {},
×
3484
                addressTypeIPv6:   {},
×
3485
                addressTypeTorV2:  {},
×
3486
                addressTypeTorV3:  {},
×
3487
                addressTypeOpaque: {},
×
3488
        }
×
3489
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3490
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3491
        }
×
3492

3493
        for _, address := range addresses {
×
3494
                switch addr := address.(type) {
×
3495
                case *net.TCPAddr:
×
3496
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3497
                                addAddr(addressTypeIPv4, addr)
×
3498
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3499
                                addAddr(addressTypeIPv6, addr)
×
3500
                        } else {
×
3501
                                return fmt.Errorf("unhandled IP address: %v",
×
3502
                                        addr)
×
3503
                        }
×
3504

3505
                case *tor.OnionAddr:
×
3506
                        switch len(addr.OnionService) {
×
3507
                        case tor.V2Len:
×
3508
                                addAddr(addressTypeTorV2, addr)
×
3509
                        case tor.V3Len:
×
3510
                                addAddr(addressTypeTorV3, addr)
×
3511
                        default:
×
3512
                                return fmt.Errorf("invalid length for a tor " +
×
3513
                                        "address")
×
3514
                        }
3515

3516
                case *lnwire.OpaqueAddrs:
×
3517
                        addAddr(addressTypeOpaque, addr)
×
3518

3519
                default:
×
3520
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3521
                }
3522
        }
3523

3524
        // Any remaining entries in newAddresses are new addresses that need to
3525
        // be added to the database for the first time.
3526
        for addrType, addrList := range newAddresses {
×
3527
                for position, addr := range addrList {
×
3528
                        err := db.InsertNodeAddress(
×
3529
                                ctx, sqlc.InsertNodeAddressParams{
×
3530
                                        NodeID:   nodeID,
×
3531
                                        Type:     int16(addrType),
×
3532
                                        Address:  addr,
×
3533
                                        Position: int32(position),
×
3534
                                },
×
3535
                        )
×
3536
                        if err != nil {
×
3537
                                return fmt.Errorf("unable to insert "+
×
3538
                                        "node(%d) address(%v): %w", nodeID,
×
3539
                                        addr, err)
×
3540
                        }
×
3541
                }
3542
        }
3543

3544
        return nil
×
3545
}
3546

3547
// getNodeAddresses fetches the addresses for a node with the given public key.
3548
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3549
        []net.Addr, error) {
×
3550

×
3551
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3552
        // are returned in the same order as they were inserted.
×
3553
        rows, err := db.GetNodeAddressesByPubKey(
×
3554
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3555
                        Version: int16(ProtocolV1),
×
3556
                        PubKey:  nodePub,
×
3557
                },
×
3558
        )
×
3559
        if err != nil {
×
3560
                return false, nil, err
×
3561
        }
×
3562

3563
        // GetNodeAddressesByPubKey uses a left join so there should always be
3564
        // at least one row returned if the node exists even if it has no
3565
        // addresses.
3566
        if len(rows) == 0 {
×
3567
                return false, nil, nil
×
3568
        }
×
3569

3570
        addresses := make([]net.Addr, 0, len(rows))
×
3571
        for _, addr := range rows {
×
3572
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3573
                        continue
×
3574
                }
3575

3576
                address := addr.Address.String
×
3577

×
3578
                switch dbAddressType(addr.Type.Int16) {
×
3579
                case addressTypeIPv4:
×
3580
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3581
                        if err != nil {
×
3582
                                return false, nil, nil
×
3583
                        }
×
3584
                        tcp.IP = tcp.IP.To4()
×
3585

×
3586
                        addresses = append(addresses, tcp)
×
3587

3588
                case addressTypeIPv6:
×
3589
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3590
                        if err != nil {
×
3591
                                return false, nil, nil
×
3592
                        }
×
3593
                        addresses = append(addresses, tcp)
×
3594

3595
                case addressTypeTorV3, addressTypeTorV2:
×
3596
                        service, portStr, err := net.SplitHostPort(address)
×
3597
                        if err != nil {
×
3598
                                return false, nil, fmt.Errorf("unable to "+
×
3599
                                        "split tor v3 address: %v",
×
3600
                                        addr.Address)
×
3601
                        }
×
3602

3603
                        port, err := strconv.Atoi(portStr)
×
3604
                        if err != nil {
×
3605
                                return false, nil, err
×
3606
                        }
×
3607

3608
                        addresses = append(addresses, &tor.OnionAddr{
×
3609
                                OnionService: service,
×
3610
                                Port:         port,
×
3611
                        })
×
3612

3613
                case addressTypeOpaque:
×
3614
                        opaque, err := hex.DecodeString(address)
×
3615
                        if err != nil {
×
3616
                                return false, nil, fmt.Errorf("unable to "+
×
3617
                                        "decode opaque address: %v", addr)
×
3618
                        }
×
3619

3620
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3621
                                Payload: opaque,
×
3622
                        })
×
3623

3624
                default:
×
3625
                        return false, nil, fmt.Errorf("unknown address "+
×
3626
                                "type: %v", addr.Type)
×
3627
                }
3628
        }
3629

3630
        return true, addresses, nil
×
3631
}
3632

3633
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3634
// database. This includes updating any existing types, inserting any new types,
3635
// and deleting any types that are no longer present.
3636
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3637
        nodeID int64, extraFields map[uint64][]byte) error {
×
3638

×
3639
        // Get any existing extra signed fields for the node.
×
3640
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3641
        if err != nil {
×
3642
                return err
×
3643
        }
×
3644

3645
        // Make a lookup map of the existing field types so that we can use it
3646
        // to keep track of any fields we should delete.
3647
        m := make(map[uint64]bool)
×
3648
        for _, field := range existingFields {
×
3649
                m[uint64(field.Type)] = true
×
3650
        }
×
3651

3652
        // For all the new fields, we'll upsert them and remove them from the
3653
        // map of existing fields.
3654
        for tlvType, value := range extraFields {
×
3655
                err = db.UpsertNodeExtraType(
×
3656
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3657
                                NodeID: nodeID,
×
3658
                                Type:   int64(tlvType),
×
3659
                                Value:  value,
×
3660
                        },
×
3661
                )
×
3662
                if err != nil {
×
3663
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3664
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3665
                }
×
3666

3667
                // Remove the field from the map of existing fields if it was
3668
                // present.
3669
                delete(m, tlvType)
×
3670
        }
3671

3672
        // For all the fields that are left in the map of existing fields, we'll
3673
        // delete them as they are no longer present in the new set of fields.
3674
        for tlvType := range m {
×
3675
                err = db.DeleteExtraNodeType(
×
3676
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3677
                                NodeID: nodeID,
×
3678
                                Type:   int64(tlvType),
×
3679
                        },
×
3680
                )
×
3681
                if err != nil {
×
3682
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3683
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3684
                }
×
3685
        }
3686

3687
        return nil
×
3688
}
3689

3690
// srcNodeInfo holds the information about the source node of the graph.
3691
type srcNodeInfo struct {
3692
        // id is the DB level ID of the source node entry in the "nodes" table.
3693
        id int64
3694

3695
        // pub is the public key of the source node.
3696
        pub route.Vertex
3697
}
3698

3699
// getSourceNode returns the DB node ID and pub key of the source node for the
3700
// specified protocol version.
3701
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3702
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3703

×
3704
        s.srcNodeMu.Lock()
×
3705
        defer s.srcNodeMu.Unlock()
×
3706

×
3707
        // If we already have the source node ID and pub key cached, then
×
3708
        // return them.
×
3709
        if info, ok := s.srcNodes[version]; ok {
×
3710
                return info.id, info.pub, nil
×
3711
        }
×
3712

3713
        var pubKey route.Vertex
×
3714

×
3715
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3716
        if err != nil {
×
3717
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3718
                        err)
×
3719
        }
×
3720

3721
        if len(nodes) == 0 {
×
3722
                return 0, pubKey, ErrSourceNodeNotSet
×
3723
        } else if len(nodes) > 1 {
×
3724
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3725
                        "protocol %s found", version)
×
3726
        }
×
3727

3728
        copy(pubKey[:], nodes[0].PubKey)
×
3729

×
3730
        s.srcNodes[version] = &srcNodeInfo{
×
3731
                id:  nodes[0].NodeID,
×
3732
                pub: pubKey,
×
3733
        }
×
3734

×
3735
        return nodes[0].NodeID, pubKey, nil
×
3736
}
3737

3738
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3739
// This then produces a map from TLV type to value. If the input is not a
3740
// valid TLV stream, then an error is returned.
3741
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3742
        r := bytes.NewReader(data)
×
3743

×
3744
        tlvStream, err := tlv.NewStream()
×
3745
        if err != nil {
×
3746
                return nil, err
×
3747
        }
×
3748

3749
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3750
        // pass it into the P2P decoding variant.
3751
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3752
        if err != nil {
×
3753
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3754
        }
×
3755
        if len(parsedTypes) == 0 {
×
3756
                return nil, nil
×
3757
        }
×
3758

3759
        records := make(map[uint64][]byte)
×
3760
        for k, v := range parsedTypes {
×
3761
                records[uint64(k)] = v
×
3762
        }
×
3763

3764
        return records, nil
×
3765
}
3766

3767
// insertChannel inserts a new channel record into the database.
3768
func insertChannel(ctx context.Context, db SQLQueries,
3769
        edge *models.ChannelEdgeInfo) error {
×
3770

×
3771
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3772

×
3773
        // Make sure that the channel doesn't already exist. We do this
×
3774
        // explicitly instead of relying on catching a unique constraint error
×
3775
        // because relying on SQL to throw that error would abort the entire
×
3776
        // batch of transactions.
×
3777
        _, err := db.GetChannelBySCID(
×
3778
                ctx, sqlc.GetChannelBySCIDParams{
×
3779
                        Scid:    chanIDB[:],
×
3780
                        Version: int16(ProtocolV1),
×
3781
                },
×
3782
        )
×
3783
        if err == nil {
×
3784
                return ErrEdgeAlreadyExist
×
3785
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3786
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3787
        }
×
3788

3789
        // Make sure that at least a "shell" entry for each node is present in
3790
        // the nodes table.
3791
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3792
        if err != nil {
×
3793
                return fmt.Errorf("unable to create shell node: %w", err)
×
3794
        }
×
3795

3796
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3797
        if err != nil {
×
3798
                return fmt.Errorf("unable to create shell node: %w", err)
×
3799
        }
×
3800

3801
        var capacity sql.NullInt64
×
3802
        if edge.Capacity != 0 {
×
3803
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3804
        }
×
3805

3806
        createParams := sqlc.CreateChannelParams{
×
3807
                Version:     int16(ProtocolV1),
×
3808
                Scid:        chanIDB[:],
×
3809
                NodeID1:     node1DBID,
×
3810
                NodeID2:     node2DBID,
×
3811
                Outpoint:    edge.ChannelPoint.String(),
×
3812
                Capacity:    capacity,
×
3813
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3814
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3815
        }
×
3816

×
3817
        if edge.AuthProof != nil {
×
3818
                proof := edge.AuthProof
×
3819

×
3820
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3821
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3822
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3823
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3824
        }
×
3825

3826
        // Insert the new channel record.
3827
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3828
        if err != nil {
×
3829
                return err
×
3830
        }
×
3831

3832
        // Insert any channel features.
3833
        if len(edge.Features) != 0 {
×
3834
                chanFeatures := lnwire.NewRawFeatureVector()
×
3835
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3836
                if err != nil {
×
3837
                        return err
×
3838
                }
×
3839

3840
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3841
                for feature := range fv.Features() {
×
3842
                        err = db.InsertChannelFeature(
×
3843
                                ctx, sqlc.InsertChannelFeatureParams{
×
3844
                                        ChannelID:  dbChanID,
×
3845
                                        FeatureBit: int32(feature),
×
3846
                                },
×
3847
                        )
×
3848
                        if err != nil {
×
3849
                                return fmt.Errorf("unable to insert "+
×
3850
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3851
                                        feature, err)
×
3852
                        }
×
3853
                }
3854
        }
3855

3856
        // Finally, insert any extra TLV fields in the channel announcement.
3857
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3858
        if err != nil {
×
3859
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3860
                        err)
×
3861
        }
×
3862

3863
        for tlvType, value := range extra {
×
3864
                err := db.CreateChannelExtraType(
×
3865
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3866
                                ChannelID: dbChanID,
×
3867
                                Type:      int64(tlvType),
×
3868
                                Value:     value,
×
3869
                        },
×
3870
                )
×
3871
                if err != nil {
×
3872
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3873
                                "signed field(%v): %w", edge.ChannelID,
×
3874
                                tlvType, err)
×
3875
                }
×
3876
        }
3877

3878
        return nil
×
3879
}
3880

3881
// maybeCreateShellNode checks if a shell node entry exists for the
3882
// given public key. If it does not exist, then a new shell node entry is
3883
// created. The ID of the node is returned. A shell node only has a protocol
3884
// version and public key persisted.
3885
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3886
        pubKey route.Vertex) (int64, error) {
×
3887

×
3888
        dbNode, err := db.GetNodeByPubKey(
×
3889
                ctx, sqlc.GetNodeByPubKeyParams{
×
3890
                        PubKey:  pubKey[:],
×
3891
                        Version: int16(ProtocolV1),
×
3892
                },
×
3893
        )
×
3894
        // The node exists. Return the ID.
×
3895
        if err == nil {
×
3896
                return dbNode.ID, nil
×
3897
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3898
                return 0, err
×
3899
        }
×
3900

3901
        // Otherwise, the node does not exist, so we create a shell entry for
3902
        // it.
3903
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3904
                Version: int16(ProtocolV1),
×
3905
                PubKey:  pubKey[:],
×
3906
        })
×
3907
        if err != nil {
×
3908
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3909
        }
×
3910

3911
        return id, nil
×
3912
}
3913

3914
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3915
// the database. This includes deleting any existing types and then inserting
3916
// the new types.
3917
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3918
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3919

×
3920
        // Delete all existing extra signed fields for the channel policy.
×
3921
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3922
        if err != nil {
×
3923
                return fmt.Errorf("unable to delete "+
×
3924
                        "existing policy extra signed fields for policy %d: %w",
×
3925
                        chanPolicyID, err)
×
3926
        }
×
3927

3928
        // Insert all new extra signed fields for the channel policy.
3929
        for tlvType, value := range extraFields {
×
3930
                err = db.InsertChanPolicyExtraType(
×
3931
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3932
                                ChannelPolicyID: chanPolicyID,
×
3933
                                Type:            int64(tlvType),
×
3934
                                Value:           value,
×
3935
                        },
×
3936
                )
×
3937
                if err != nil {
×
3938
                        return fmt.Errorf("unable to insert "+
×
3939
                                "channel_policy(%d) extra signed field(%v): %w",
×
3940
                                chanPolicyID, tlvType, err)
×
3941
                }
×
3942
        }
3943

3944
        return nil
×
3945
}
3946

3947
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3948
// provided dbChanRow and also fetches any other required information
3949
// to construct the edge info.
3950
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3951
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3952
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3953

×
3954
        if dbChan.Version != int16(ProtocolV1) {
×
3955
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3956
                        dbChan.Version)
×
3957
        }
×
3958

3959
        fv, extras, err := getChanFeaturesAndExtras(
×
3960
                ctx, db, dbChanID,
×
3961
        )
×
3962
        if err != nil {
×
3963
                return nil, err
×
3964
        }
×
3965

3966
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3967
        if err != nil {
×
3968
                return nil, err
×
3969
        }
×
3970

3971
        var featureBuf bytes.Buffer
×
3972
        if err := fv.Encode(&featureBuf); err != nil {
×
3973
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3974
        }
×
3975

3976
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3977
        if err != nil {
×
3978
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3979
                        "fields: %w", err)
×
3980
        }
×
3981
        if recs == nil {
×
3982
                recs = make([]byte, 0)
×
3983
        }
×
3984

3985
        var btcKey1, btcKey2 route.Vertex
×
3986
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3987
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3988

×
3989
        channel := &models.ChannelEdgeInfo{
×
3990
                ChainHash:        chain,
×
3991
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3992
                NodeKey1Bytes:    node1,
×
3993
                NodeKey2Bytes:    node2,
×
3994
                BitcoinKey1Bytes: btcKey1,
×
3995
                BitcoinKey2Bytes: btcKey2,
×
3996
                ChannelPoint:     *op,
×
3997
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3998
                Features:         featureBuf.Bytes(),
×
3999
                ExtraOpaqueData:  recs,
×
4000
        }
×
4001

×
4002
        // We always set all the signatures at the same time, so we can
×
4003
        // safely check if one signature is present to determine if we have the
×
4004
        // rest of the signatures for the auth proof.
×
4005
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4006
                channel.AuthProof = &models.ChannelAuthProof{
×
4007
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4008
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4009
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4010
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4011
                }
×
4012
        }
×
4013

4014
        return channel, nil
×
4015
}
4016

4017
// buildNodeVertices is a helper that converts raw node public keys
4018
// into route.Vertex instances.
4019
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4020
        route.Vertex, error) {
×
4021

×
4022
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4023
        if err != nil {
×
4024
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4025
                        "create vertex from node1 pubkey: %w", err)
×
4026
        }
×
4027

4028
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4029
        if err != nil {
×
4030
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4031
                        "create vertex from node2 pubkey: %w", err)
×
4032
        }
×
4033

4034
        return node1Vertex, node2Vertex, nil
×
4035
}
4036

4037
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4038
// for a channel with the given ID.
4039
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4040
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4041

×
4042
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4043
        if err != nil {
×
4044
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4045
                        "features and extras: %w", err)
×
4046
        }
×
4047

4048
        var (
×
4049
                fv     = lnwire.EmptyFeatureVector()
×
4050
                extras = make(map[uint64][]byte)
×
4051
        )
×
4052
        for _, row := range rows {
×
4053
                if row.IsFeature {
×
4054
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4055

×
4056
                        continue
×
4057
                }
4058

4059
                tlvType, ok := row.ExtraKey.(int64)
×
4060
                if !ok {
×
4061
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4062
                                "TLV type: %T", row.ExtraKey)
×
4063
                }
×
4064

4065
                valueBytes, ok := row.Value.([]byte)
×
4066
                if !ok {
×
4067
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4068
                                "Value: %T", row.Value)
×
4069
                }
×
4070

4071
                extras[uint64(tlvType)] = valueBytes
×
4072
        }
4073

4074
        return fv, extras, nil
×
4075
}
4076

4077
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4078
// all the extra info required to build the complete models.ChannelEdgePolicy
4079
// types. It returns two policies, which may be nil if the provided
4080
// sqlc.ChannelPolicy records are nil.
4081
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4082
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4083
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4084
        *models.ChannelEdgePolicy, error) {
×
4085

×
4086
        if dbPol1 == nil && dbPol2 == nil {
×
4087
                return nil, nil, nil
×
4088
        }
×
4089

4090
        var (
×
4091
                policy1ID int64
×
4092
                policy2ID int64
×
4093
        )
×
4094
        if dbPol1 != nil {
×
4095
                policy1ID = dbPol1.ID
×
4096
        }
×
4097
        if dbPol2 != nil {
×
4098
                policy2ID = dbPol2.ID
×
4099
        }
×
4100
        rows, err := db.GetChannelPolicyExtraTypes(
×
4101
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4102
                        ID:   policy1ID,
×
4103
                        ID_2: policy2ID,
×
4104
                },
×
4105
        )
×
4106
        if err != nil {
×
4107
                return nil, nil, err
×
4108
        }
×
4109

4110
        var (
×
4111
                dbPol1Extras = make(map[uint64][]byte)
×
4112
                dbPol2Extras = make(map[uint64][]byte)
×
4113
        )
×
4114
        for _, row := range rows {
×
4115
                switch row.PolicyID {
×
4116
                case policy1ID:
×
4117
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4118
                case policy2ID:
×
4119
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4120
                default:
×
4121
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4122
                                "in row: %v", row.PolicyID, row)
×
4123
                }
4124
        }
4125

4126
        var pol1, pol2 *models.ChannelEdgePolicy
×
4127
        if dbPol1 != nil {
×
4128
                pol1, err = buildChanPolicy(
×
4129
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
4130
                )
×
4131
                if err != nil {
×
4132
                        return nil, nil, err
×
4133
                }
×
4134
        }
4135
        if dbPol2 != nil {
×
4136
                pol2, err = buildChanPolicy(
×
4137
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
4138
                )
×
4139
                if err != nil {
×
4140
                        return nil, nil, err
×
4141
                }
×
4142
        }
4143

4144
        return pol1, pol2, nil
×
4145
}
4146

4147
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4148
// provided sqlc.ChannelPolicy and other required information.
4149
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4150
        extras map[uint64][]byte, toNode route.Vertex,
4151
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
4152

×
4153
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4154
        if err != nil {
×
4155
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4156
                        "fields: %w", err)
×
4157
        }
×
4158

4159
        var msgFlags lnwire.ChanUpdateMsgFlags
×
4160
        if dbPolicy.MaxHtlcMsat.Valid {
×
4161
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
4162
        }
×
4163

4164
        var chanFlags lnwire.ChanUpdateChanFlags
×
4165
        if !isNode1 {
×
4166
                chanFlags |= lnwire.ChanUpdateDirection
×
4167
        }
×
4168
        if dbPolicy.Disabled.Bool {
×
4169
                chanFlags |= lnwire.ChanUpdateDisabled
×
4170
        }
×
4171

4172
        var inboundFee fn.Option[lnwire.Fee]
×
4173
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4174
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4175

×
4176
                inboundFee = fn.Some(lnwire.Fee{
×
4177
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4178
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4179
                })
×
4180
        }
×
4181

4182
        return &models.ChannelEdgePolicy{
×
4183
                SigBytes:  dbPolicy.Signature,
×
4184
                ChannelID: channelID,
×
4185
                LastUpdate: time.Unix(
×
4186
                        dbPolicy.LastUpdate.Int64, 0,
×
4187
                ),
×
4188
                MessageFlags:  msgFlags,
×
4189
                ChannelFlags:  chanFlags,
×
4190
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4191
                MinHTLC: lnwire.MilliSatoshi(
×
4192
                        dbPolicy.MinHtlcMsat,
×
4193
                ),
×
4194
                MaxHTLC: lnwire.MilliSatoshi(
×
4195
                        dbPolicy.MaxHtlcMsat.Int64,
×
4196
                ),
×
4197
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4198
                        dbPolicy.BaseFeeMsat,
×
4199
                ),
×
4200
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4201
                ToNode:                    toNode,
×
4202
                InboundFee:                inboundFee,
×
4203
                ExtraOpaqueData:           recs,
×
4204
        }, nil
×
4205
}
4206

4207
// buildNodes builds the models.LightningNode instances for the
4208
// given row which is expected to be a sqlc type that contains node information.
4209
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4210
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4211
        error) {
×
4212

×
4213
        node1, err := buildNode(ctx, db, &dbNode1)
×
4214
        if err != nil {
×
4215
                return nil, nil, err
×
4216
        }
×
4217

4218
        node2, err := buildNode(ctx, db, &dbNode2)
×
4219
        if err != nil {
×
4220
                return nil, nil, err
×
4221
        }
×
4222

4223
        return node1, node2, nil
×
4224
}
4225

4226
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4227
// row which is expected to be a sqlc type that contains channel policy
4228
// information. It returns two policies, which may be nil if the policy
4229
// information is not present in the row.
4230
//
4231
//nolint:ll,dupl,funlen
4232
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4233
        error) {
×
4234

×
4235
        var policy1, policy2 *sqlc.ChannelPolicy
×
4236
        switch r := row.(type) {
×
4237
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4238
                if r.Policy1ID.Valid {
×
4239
                        policy1 = &sqlc.ChannelPolicy{
×
4240
                                ID:                      r.Policy1ID.Int64,
×
4241
                                Version:                 r.Policy1Version.Int16,
×
4242
                                ChannelID:               r.Channel.ID,
×
4243
                                NodeID:                  r.Policy1NodeID.Int64,
×
4244
                                Timelock:                r.Policy1Timelock.Int32,
×
4245
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4246
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4247
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4248
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4249
                                LastUpdate:              r.Policy1LastUpdate,
×
4250
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4251
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4252
                                Disabled:                r.Policy1Disabled,
×
4253
                                Signature:               r.Policy1Signature,
×
4254
                        }
×
4255
                }
×
4256
                if r.Policy2ID.Valid {
×
4257
                        policy2 = &sqlc.ChannelPolicy{
×
4258
                                ID:                      r.Policy2ID.Int64,
×
4259
                                Version:                 r.Policy2Version.Int16,
×
4260
                                ChannelID:               r.Channel.ID,
×
4261
                                NodeID:                  r.Policy2NodeID.Int64,
×
4262
                                Timelock:                r.Policy2Timelock.Int32,
×
4263
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4264
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4265
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4266
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4267
                                LastUpdate:              r.Policy2LastUpdate,
×
4268
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4269
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4270
                                Disabled:                r.Policy2Disabled,
×
4271
                                Signature:               r.Policy2Signature,
×
4272
                        }
×
4273
                }
×
4274

4275
                return policy1, policy2, nil
×
4276

4277
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4278
                if r.Policy1ID.Valid {
×
4279
                        policy1 = &sqlc.ChannelPolicy{
×
4280
                                ID:                      r.Policy1ID.Int64,
×
4281
                                Version:                 r.Policy1Version.Int16,
×
4282
                                ChannelID:               r.Channel.ID,
×
4283
                                NodeID:                  r.Policy1NodeID.Int64,
×
4284
                                Timelock:                r.Policy1Timelock.Int32,
×
4285
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4286
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4287
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4288
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4289
                                LastUpdate:              r.Policy1LastUpdate,
×
4290
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4291
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4292
                                Disabled:                r.Policy1Disabled,
×
4293
                                Signature:               r.Policy1Signature,
×
4294
                        }
×
4295
                }
×
4296
                if r.Policy2ID.Valid {
×
4297
                        policy2 = &sqlc.ChannelPolicy{
×
4298
                                ID:                      r.Policy2ID.Int64,
×
4299
                                Version:                 r.Policy2Version.Int16,
×
4300
                                ChannelID:               r.Channel.ID,
×
4301
                                NodeID:                  r.Policy2NodeID.Int64,
×
4302
                                Timelock:                r.Policy2Timelock.Int32,
×
4303
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4304
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4305
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4306
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4307
                                LastUpdate:              r.Policy2LastUpdate,
×
4308
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4309
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4310
                                Disabled:                r.Policy2Disabled,
×
4311
                                Signature:               r.Policy2Signature,
×
4312
                        }
×
4313
                }
×
4314

4315
                return policy1, policy2, nil
×
4316

4317
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4318
                if r.Policy1ID.Valid {
×
4319
                        policy1 = &sqlc.ChannelPolicy{
×
4320
                                ID:                      r.Policy1ID.Int64,
×
4321
                                Version:                 r.Policy1Version.Int16,
×
4322
                                ChannelID:               r.Channel.ID,
×
4323
                                NodeID:                  r.Policy1NodeID.Int64,
×
4324
                                Timelock:                r.Policy1Timelock.Int32,
×
4325
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4326
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4327
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4328
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4329
                                LastUpdate:              r.Policy1LastUpdate,
×
4330
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4331
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4332
                                Disabled:                r.Policy1Disabled,
×
4333
                                Signature:               r.Policy1Signature,
×
4334
                        }
×
4335
                }
×
4336
                if r.Policy2ID.Valid {
×
4337
                        policy2 = &sqlc.ChannelPolicy{
×
4338
                                ID:                      r.Policy2ID.Int64,
×
4339
                                Version:                 r.Policy2Version.Int16,
×
4340
                                ChannelID:               r.Channel.ID,
×
4341
                                NodeID:                  r.Policy2NodeID.Int64,
×
4342
                                Timelock:                r.Policy2Timelock.Int32,
×
4343
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4344
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4345
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4346
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4347
                                LastUpdate:              r.Policy2LastUpdate,
×
4348
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4349
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4350
                                Disabled:                r.Policy2Disabled,
×
4351
                                Signature:               r.Policy2Signature,
×
4352
                        }
×
4353
                }
×
4354

4355
                return policy1, policy2, nil
×
4356

4357
        case sqlc.ListChannelsByNodeIDRow:
×
4358
                if r.Policy1ID.Valid {
×
4359
                        policy1 = &sqlc.ChannelPolicy{
×
4360
                                ID:                      r.Policy1ID.Int64,
×
4361
                                Version:                 r.Policy1Version.Int16,
×
4362
                                ChannelID:               r.Channel.ID,
×
4363
                                NodeID:                  r.Policy1NodeID.Int64,
×
4364
                                Timelock:                r.Policy1Timelock.Int32,
×
4365
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4366
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4367
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4368
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4369
                                LastUpdate:              r.Policy1LastUpdate,
×
4370
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4371
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4372
                                Disabled:                r.Policy1Disabled,
×
4373
                                Signature:               r.Policy1Signature,
×
4374
                        }
×
4375
                }
×
4376
                if r.Policy2ID.Valid {
×
4377
                        policy2 = &sqlc.ChannelPolicy{
×
4378
                                ID:                      r.Policy2ID.Int64,
×
4379
                                Version:                 r.Policy2Version.Int16,
×
4380
                                ChannelID:               r.Channel.ID,
×
4381
                                NodeID:                  r.Policy2NodeID.Int64,
×
4382
                                Timelock:                r.Policy2Timelock.Int32,
×
4383
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4384
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4385
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4386
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4387
                                LastUpdate:              r.Policy2LastUpdate,
×
4388
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4389
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4390
                                Disabled:                r.Policy2Disabled,
×
4391
                                Signature:               r.Policy2Signature,
×
4392
                        }
×
4393
                }
×
4394

4395
                return policy1, policy2, nil
×
4396

4397
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4398
                if r.Policy1ID.Valid {
×
4399
                        policy1 = &sqlc.ChannelPolicy{
×
4400
                                ID:                      r.Policy1ID.Int64,
×
4401
                                Version:                 r.Policy1Version.Int16,
×
4402
                                ChannelID:               r.Channel.ID,
×
4403
                                NodeID:                  r.Policy1NodeID.Int64,
×
4404
                                Timelock:                r.Policy1Timelock.Int32,
×
4405
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4406
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4407
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4408
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4409
                                LastUpdate:              r.Policy1LastUpdate,
×
4410
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4411
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4412
                                Disabled:                r.Policy1Disabled,
×
4413
                                Signature:               r.Policy1Signature,
×
4414
                        }
×
4415
                }
×
4416
                if r.Policy2ID.Valid {
×
4417
                        policy2 = &sqlc.ChannelPolicy{
×
4418
                                ID:                      r.Policy2ID.Int64,
×
4419
                                Version:                 r.Policy2Version.Int16,
×
4420
                                ChannelID:               r.Channel.ID,
×
4421
                                NodeID:                  r.Policy2NodeID.Int64,
×
4422
                                Timelock:                r.Policy2Timelock.Int32,
×
4423
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4424
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4425
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4426
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4427
                                LastUpdate:              r.Policy2LastUpdate,
×
4428
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4429
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4430
                                Disabled:                r.Policy2Disabled,
×
4431
                                Signature:               r.Policy2Signature,
×
4432
                        }
×
4433
                }
×
4434

4435
                return policy1, policy2, nil
×
4436
        default:
×
4437
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4438
                        "extractChannelPolicies: %T", r)
×
4439
        }
4440
}
4441

4442
// channelIDToBytes converts a channel ID (SCID) to a byte array
4443
// representation.
4444
func channelIDToBytes(channelID uint64) [8]byte {
×
4445
        var chanIDB [8]byte
×
4446
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4447

×
4448
        return chanIDB
×
4449
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc