• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15758703842

19 Jun 2025 01:11PM UTC coverage: 67.933% (-0.2%) from 68.161%
15758703842

Pull #9936

github

web-flow
Merge 2ee302abf into e0a9705d5
Pull Request #9936: [12] graph/db: Implement more graph SQLStore methods

0 of 651 new or added lines in 2 files covered. (0.0%)

71 existing lines in 21 files now uncovered.

134493 of 197979 relevant lines covered (67.93%)

22115.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "math"
11
        "net"
12
        "sort"
13
        "strconv"
14
        "sync"
15
        "time"
16

17
        "github.com/btcsuite/btcd/btcec/v2"
18
        "github.com/btcsuite/btcd/btcutil"
19
        "github.com/btcsuite/btcd/chaincfg/chainhash"
20
        "github.com/btcsuite/btcd/wire"
21
        "github.com/lightningnetwork/lnd/batch"
22
        "github.com/lightningnetwork/lnd/fn/v2"
23
        "github.com/lightningnetwork/lnd/graph/db/models"
24
        "github.com/lightningnetwork/lnd/lnwire"
25
        "github.com/lightningnetwork/lnd/routing/route"
26
        "github.com/lightningnetwork/lnd/sqldb"
27
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
28
        "github.com/lightningnetwork/lnd/tlv"
29
        "github.com/lightningnetwork/lnd/tor"
30
)
31

32
// pageSize is the limit for the number of records that can be returned
33
// in a paginated query. This can be tuned after some benchmarks.
34
const pageSize = 2000
35

36
// ProtocolVersion is an enum that defines the gossip protocol version of a
37
// message.
38
type ProtocolVersion uint8

const (
	// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by the
// version's decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying integer so that the numeric value is
	// printed. No space is inserted since the first operand is a string.
	return fmt.Sprint("V", uint8(v))
}
×
49

50
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
51
// execute queries against the SQL graph tables.
52
//
53
//nolint:ll,interfacebloat
54
type SQLQueries interface {
55
        /*
56
                Node queries.
57
        */
58
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
59
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
60
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
61
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
62
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
63
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
64
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
65

66
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
67
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
68
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
69

70
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
71
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
72
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
73

74
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
75
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
76
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
77
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
78

79
        /*
80
                Source node queries.
81
        */
82
        AddSourceNode(ctx context.Context, nodeID int64) error
83
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
84

85
        /*
86
                Channel queries.
87
        */
88
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
89
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
90
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
91
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
92
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
93
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
94
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
95
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
96
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
97

98
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
99
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
100

101
        /*
102
                Channel Policy table queries.
103
        */
104
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
105
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
106

107
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
108
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
109
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
110
}
111

112
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
113
// database operations.
114
type BatchedSQLQueries interface {
115
        SQLQueries
116
        sqldb.BatchedTx[SQLQueries]
117
}
118

119
// SQLStore is an implementation of the V1Store interface that uses a SQL
120
// database as the backend.
121
//
122
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
123
// implement the V1Store interface incrementally. For any method not
124
// implemented, things will fall back to the KVStore. This is ONLY the case
125
// for the time being while this struct is purely used in unit tests only.
126
type SQLStore struct {
127
        cfg *SQLStoreConfig
128
        db  BatchedSQLQueries
129

130
        // cacheMu guards all caches (rejectCache and chanCache). If
131
        // this mutex will be acquired at the same time as the DB mutex then
132
        // the cacheMu MUST be acquired first to prevent deadlock.
133
        cacheMu     sync.RWMutex
134
        rejectCache *rejectCache
135
        chanCache   *channelCache
136

137
        chanScheduler batch.Scheduler[SQLQueries]
138
        nodeScheduler batch.Scheduler[SQLQueries]
139

140
        srcNodes  map[ProtocolVersion]*srcNodeInfo
141
        srcNodeMu sync.Mutex
142

143
        // Temporary fall-back to the KVStore so that we can implement the
144
        // interface incrementally.
145
        *KVStore
146
}
147

148
// A compile-time assertion to ensure that SQLStore implements the V1Store
149
// interface.
150
var _ V1Store = (*SQLStore)(nil)
151

152
// SQLStoreConfig holds the configuration for the SQLStore.
153
type SQLStoreConfig struct {
154
        // ChainHash is the genesis hash for the chain that all the gossip
155
        // messages in this store are aimed at.
156
        ChainHash chainhash.Hash
157
}
158

159
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
160
// storage backend.
161
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
162
        options ...StoreOptionModifier) (*SQLStore, error) {
×
163

×
164
        opts := DefaultOptions()
×
165
        for _, o := range options {
×
166
                o(opts)
×
167
        }
×
168

169
        if opts.NoMigration {
×
170
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
171
                        "supported for SQL stores")
×
172
        }
×
173

174
        s := &SQLStore{
×
175
                cfg:         cfg,
×
176
                db:          db,
×
177
                KVStore:     kvStore,
×
178
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
179
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
180
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
181
        }
×
182

×
183
        s.chanScheduler = batch.NewTimeScheduler(
×
184
                db, &s.cacheMu, opts.BatchCommitInterval,
×
185
        )
×
186
        s.nodeScheduler = batch.NewTimeScheduler(
×
187
                db, nil, opts.BatchCommitInterval,
×
188
        )
×
189

×
190
        return s, nil
×
191
}
192

193
// AddLightningNode adds a vertex/node to the graph database. If the node is not
194
// in the database from before, this will add a new, unconnected one to the
195
// graph. If it is present from before, this will update that node's
196
// information.
197
//
198
// NOTE: part of the V1Store interface.
199
func (s *SQLStore) AddLightningNode(ctx context.Context,
200
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
201

×
202
        r := &batch.Request[SQLQueries]{
×
203
                Opts: batch.NewSchedulerOptions(opts...),
×
204
                Do: func(queries SQLQueries) error {
×
205
                        _, err := upsertNode(ctx, queries, node)
×
206
                        return err
×
207
                },
×
208
        }
209

210
        return s.nodeScheduler.Execute(ctx, r)
×
211
}
212

213
// FetchLightningNode attempts to look up a target node by its identity public
214
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
215
// returned.
216
//
217
// NOTE: part of the V1Store interface.
218
func (s *SQLStore) FetchLightningNode(ctx context.Context,
219
        pubKey route.Vertex) (*models.LightningNode, error) {
×
220

×
221
        var node *models.LightningNode
×
222
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
223
                var err error
×
224
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
225

×
226
                return err
×
227
        }, sqldb.NoOpReset)
×
228
        if err != nil {
×
229
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
230
        }
×
231

232
        return node, nil
×
233
}
234

235
// HasLightningNode determines if the graph has a vertex identified by the
236
// target node identity public key. If the node exists in the database, a
237
// timestamp of when the data for the node was last updated is returned along
238
// with a true boolean. Otherwise, an empty time.Time is returned with a false
239
// boolean.
240
//
241
// NOTE: part of the V1Store interface.
242
func (s *SQLStore) HasLightningNode(ctx context.Context,
243
        pubKey [33]byte) (time.Time, bool, error) {
×
244

×
245
        var (
×
246
                exists     bool
×
247
                lastUpdate time.Time
×
248
        )
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                dbNode, err := db.GetNodeByPubKey(
×
251
                        ctx, sqlc.GetNodeByPubKeyParams{
×
252
                                Version: int16(ProtocolV1),
×
253
                                PubKey:  pubKey[:],
×
254
                        },
×
255
                )
×
256
                if errors.Is(err, sql.ErrNoRows) {
×
257
                        return nil
×
258
                } else if err != nil {
×
259
                        return fmt.Errorf("unable to fetch node: %w", err)
×
260
                }
×
261

262
                exists = true
×
263

×
264
                if dbNode.LastUpdate.Valid {
×
265
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
266
                }
×
267

268
                return nil
×
269
        }, sqldb.NoOpReset)
270
        if err != nil {
×
271
                return time.Time{}, false,
×
272
                        fmt.Errorf("unable to fetch node: %w", err)
×
273
        }
×
274

275
        return lastUpdate, exists, nil
×
276
}
277

278
// AddrsForNode returns all known addresses for the target node public key
279
// that the graph DB is aware of. The returned boolean indicates if the
280
// given node is unknown to the graph DB or not.
281
//
282
// NOTE: part of the V1Store interface.
283
func (s *SQLStore) AddrsForNode(ctx context.Context,
284
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
285

×
286
        var (
×
287
                addresses []net.Addr
×
288
                known     bool
×
289
        )
×
290
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
291
                var err error
×
292
                known, addresses, err = getNodeAddresses(
×
293
                        ctx, db, nodePub.SerializeCompressed(),
×
294
                )
×
295
                if err != nil {
×
296
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
297
                                err)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
304
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
305
        }
×
306

307
        return known, addresses, nil
×
308
}
309

310
// DeleteLightningNode starts a new database transaction to remove a vertex/node
311
// from the database according to the node's public key.
312
//
313
// NOTE: part of the V1Store interface.
314
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
315
        pubKey route.Vertex) error {
×
316

×
317
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
318
                res, err := db.DeleteNodeByPubKey(
×
319
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
320
                                Version: int16(ProtocolV1),
×
321
                                PubKey:  pubKey[:],
×
322
                        },
×
323
                )
×
324
                if err != nil {
×
325
                        return err
×
326
                }
×
327

328
                rows, err := res.RowsAffected()
×
329
                if err != nil {
×
330
                        return err
×
331
                }
×
332

333
                if rows == 0 {
×
334
                        return ErrGraphNodeNotFound
×
335
                } else if rows > 1 {
×
336
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
337
                }
×
338

339
                return err
×
340
        }, sqldb.NoOpReset)
341
        if err != nil {
×
342
                return fmt.Errorf("unable to delete node: %w", err)
×
343
        }
×
344

345
        return nil
×
346
}
347

348
// FetchNodeFeatures returns the features of the given node. If no features are
349
// known for the node, an empty feature vector is returned.
350
//
351
// NOTE: this is part of the graphdb.NodeTraverser interface.
352
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
353
        *lnwire.FeatureVector, error) {
×
354

×
355
        ctx := context.TODO()
×
356

×
357
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
358
}
×
359

360
// LookupAlias attempts to return the alias as advertised by the target node.
361
//
362
// NOTE: part of the V1Store interface.
363
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
×
364
        var (
×
365
                ctx   = context.TODO()
×
366
                alias string
×
367
        )
×
368
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
369
                dbNode, err := db.GetNodeByPubKey(
×
370
                        ctx, sqlc.GetNodeByPubKeyParams{
×
371
                                Version: int16(ProtocolV1),
×
372
                                PubKey:  pub.SerializeCompressed(),
×
373
                        },
×
374
                )
×
375
                if errors.Is(err, sql.ErrNoRows) {
×
376
                        return ErrNodeAliasNotFound
×
377
                } else if err != nil {
×
378
                        return fmt.Errorf("unable to fetch node: %w", err)
×
379
                }
×
380

381
                if !dbNode.Alias.Valid {
×
382
                        return ErrNodeAliasNotFound
×
383
                }
×
384

385
                alias = dbNode.Alias.String
×
386

×
387
                return nil
×
388
        }, sqldb.NoOpReset)
389
        if err != nil {
×
390
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
391
        }
×
392

393
        return alias, nil
×
394
}
395

396
// SourceNode returns the source node of the graph. The source node is treated
397
// as the center node within a star-graph. This method may be used to kick off
398
// a path finding algorithm in order to explore the reachability of another
399
// node based off the source node.
400
//
401
// NOTE: part of the V1Store interface.
402
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
×
403
        ctx := context.TODO()
×
404

×
405
        var node *models.LightningNode
×
406
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
407
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
408
                if err != nil {
×
409
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
410
                                err)
×
411
                }
×
412

413
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
414

×
415
                return err
×
416
        }, sqldb.NoOpReset)
417
        if err != nil {
×
418
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
419
        }
×
420

421
        return node, nil
×
422
}
423

424
// SetSourceNode sets the source node within the graph database. The source
425
// node is to be used as the center of a star-graph within path finding
426
// algorithms.
427
//
428
// NOTE: part of the V1Store interface.
429
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
×
430
        ctx := context.TODO()
×
431

×
432
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
433
                id, err := upsertNode(ctx, db, node)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to upsert source node: %w",
×
436
                                err)
×
437
                }
×
438

439
                // Make sure that if a source node for this version is already
440
                // set, then the ID is the same as the one we are about to set.
441
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
442
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
443
                        return fmt.Errorf("unable to fetch source node: %w",
×
444
                                err)
×
445
                } else if err == nil {
×
446
                        if dbSourceNodeID != id {
×
447
                                return fmt.Errorf("v1 source node already "+
×
448
                                        "set to a different node: %d vs %d",
×
449
                                        dbSourceNodeID, id)
×
450
                        }
×
451

452
                        return nil
×
453
                }
454

455
                return db.AddSourceNode(ctx, id)
×
456
        }, sqldb.NoOpReset)
457
}
458

459
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
460
// update timestamp within the passed range. This method can be used by two
461
// nodes to quickly determine if they have the same set of up to date node
462
// announcements.
463
//
464
// NOTE: This is part of the V1Store interface.
465
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
466
        endTime time.Time) ([]models.LightningNode, error) {
×
467

×
468
        ctx := context.TODO()
×
469

×
470
        var nodes []models.LightningNode
×
471
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
472
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
473
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
474
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
475
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
476
                        },
×
477
                )
×
478
                if err != nil {
×
479
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
480
                }
×
481

482
                for _, dbNode := range dbNodes {
×
483
                        node, err := buildNode(ctx, db, &dbNode)
×
484
                        if err != nil {
×
485
                                return fmt.Errorf("unable to build node: %w",
×
486
                                        err)
×
487
                        }
×
488

489
                        nodes = append(nodes, *node)
×
490
                }
491

492
                return nil
×
493
        }, sqldb.NoOpReset)
494
        if err != nil {
×
495
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
496
        }
×
497

498
        return nodes, nil
×
499
}
500

501
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
502
// undirected edge between the two target nodes is created. The information stored
503
// denotes the static attributes of the channel, such as the channelID, the keys
504
// involved in creation of the channel, and the set of features that the channel
505
// supports. The chanPoint and chanID are used to uniquely identify the edge
506
// globally within the database.
507
//
508
// NOTE: part of the V1Store interface.
509
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
510
        opts ...batch.SchedulerOption) error {
×
511

×
512
        ctx := context.TODO()
×
513

×
514
        var alreadyExists bool
×
515
        r := &batch.Request[SQLQueries]{
×
516
                Opts: batch.NewSchedulerOptions(opts...),
×
517
                Reset: func() {
×
518
                        alreadyExists = false
×
519
                },
×
520
                Do: func(tx SQLQueries) error {
×
521
                        err := insertChannel(ctx, tx, edge)
×
522

×
523
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
524
                        // succeed, but propagate the error via local state.
×
525
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
526
                                alreadyExists = true
×
527
                                return nil
×
528
                        }
×
529

530
                        return err
×
531
                },
532
                OnCommit: func(err error) error {
×
533
                        switch {
×
534
                        case err != nil:
×
535
                                return err
×
536
                        case alreadyExists:
×
537
                                return ErrEdgeAlreadyExist
×
538
                        default:
×
539
                                s.rejectCache.remove(edge.ChannelID)
×
540
                                s.chanCache.remove(edge.ChannelID)
×
541
                                return nil
×
542
                        }
543
                },
544
        }
545

546
        return s.chanScheduler.Execute(ctx, r)
×
547
}
548

549
// HighestChanID returns the "highest" known channel ID in the channel graph.
550
// This represents the "newest" channel from the PoV of the chain. This method
551
// can be used by peers to quickly determine if their graphs are in sync.
552
//
553
// NOTE: This is part of the V1Store interface.
554
func (s *SQLStore) HighestChanID() (uint64, error) {
×
555
        ctx := context.TODO()
×
556

×
557
        var highestChanID uint64
×
558
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
559
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
560
                if errors.Is(err, sql.ErrNoRows) {
×
561
                        return nil
×
562
                } else if err != nil {
×
563
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
564
                                err)
×
565
                }
×
566

567
                highestChanID = byteOrder.Uint64(chanID)
×
568

×
569
                return nil
×
570
        }, sqldb.NoOpReset)
571
        if err != nil {
×
572
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
573
        }
×
574

575
        return highestChanID, nil
×
576
}
577

578
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
579
// within the database for the referenced channel. The `flags` attribute within
580
// the ChannelEdgePolicy determines which of the directed edges are being
581
// updated. If the flag is 1, then the first node's information is being
582
// updated, otherwise it's the second node's information. The node ordering is
583
// determined by the lexicographical ordering of the identity public keys of the
584
// nodes on either side of the channel.
585
//
586
// NOTE: part of the V1Store interface.
587
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
588
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
589

×
590
        ctx := context.TODO()
×
591

×
592
        var (
×
593
                isUpdate1    bool
×
594
                edgeNotFound bool
×
595
                from, to     route.Vertex
×
596
        )
×
597

×
598
        r := &batch.Request[SQLQueries]{
×
599
                Opts: batch.NewSchedulerOptions(opts...),
×
600
                Reset: func() {
×
601
                        isUpdate1 = false
×
602
                        edgeNotFound = false
×
603
                },
×
604
                Do: func(tx SQLQueries) error {
×
605
                        var err error
×
606
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
607
                                ctx, tx, edge,
×
608
                        )
×
609
                        if err != nil {
×
610
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
611
                        }
×
612

613
                        // Silence ErrEdgeNotFound so that the batch can
614
                        // succeed, but propagate the error via local state.
615
                        if errors.Is(err, ErrEdgeNotFound) {
×
616
                                edgeNotFound = true
×
617
                                return nil
×
618
                        }
×
619

620
                        return err
×
621
                },
622
                OnCommit: func(err error) error {
×
623
                        switch {
×
624
                        case err != nil:
×
625
                                return err
×
626
                        case edgeNotFound:
×
627
                                return ErrEdgeNotFound
×
628
                        default:
×
629
                                s.updateEdgeCache(edge, isUpdate1)
×
630
                                return nil
×
631
                        }
632
                },
633
        }
634

635
        err := s.chanScheduler.Execute(ctx, r)
×
636

×
637
        return from, to, err
×
638
}
639

640
// updateEdgeCache updates our reject and channel caches with the new
641
// edge policy information.
642
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
643
        isUpdate1 bool) {
×
644

×
645
        // If an entry for this channel is found in reject cache, we'll modify
×
646
        // the entry with the updated timestamp for the direction that was just
×
647
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
648
        // during the next query for this edge.
×
649
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
650
                if isUpdate1 {
×
651
                        entry.upd1Time = e.LastUpdate.Unix()
×
652
                } else {
×
653
                        entry.upd2Time = e.LastUpdate.Unix()
×
654
                }
×
655
                s.rejectCache.insert(e.ChannelID, entry)
×
656
        }
657

658
        // If an entry for this channel is found in channel cache, we'll modify
659
        // the entry with the updated policy for the direction that was just
660
        // written. If the edge doesn't exist, we'll defer loading the info and
661
        // policies and lazily read from disk during the next query.
662
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
663
                if isUpdate1 {
×
664
                        channel.Policy1 = e
×
665
                } else {
×
666
                        channel.Policy2 = e
×
667
                }
×
668
                s.chanCache.insert(e.ChannelID, channel)
×
669
        }
670
}
671

672
// ForEachSourceNodeChannel iterates through all channels of the source node,
673
// executing the passed callback on each. The call-back is provided with the
674
// channel's outpoint, whether we have a policy for the channel and the channel
675
// peer's node information.
676
//
677
// NOTE: part of the V1Store interface.
678
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
679
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
680

×
681
        var ctx = context.TODO()
×
682

×
683
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
684
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
685
                if err != nil {
×
686
                        return fmt.Errorf("unable to fetch source node: %w",
×
687
                                err)
×
688
                }
×
689

690
                return forEachNodeChannel(
×
691
                        ctx, db, s.cfg.ChainHash, nodeID,
×
692
                        func(info *models.ChannelEdgeInfo,
×
693
                                outPolicy *models.ChannelEdgePolicy,
×
694
                                _ *models.ChannelEdgePolicy) error {
×
695

×
696
                                // Fetch the other node.
×
697
                                var (
×
698
                                        otherNodePub [33]byte
×
699
                                        node1        = info.NodeKey1Bytes
×
700
                                        node2        = info.NodeKey2Bytes
×
701
                                )
×
702
                                switch {
×
703
                                case bytes.Equal(node1[:], nodePub[:]):
×
704
                                        otherNodePub = node2
×
705
                                case bytes.Equal(node2[:], nodePub[:]):
×
706
                                        otherNodePub = node1
×
707
                                default:
×
708
                                        return fmt.Errorf("node not " +
×
709
                                                "participating in this channel")
×
710
                                }
711

712
                                _, otherNode, err := getNodeByPubKey(
×
713
                                        ctx, db, otherNodePub,
×
714
                                )
×
715
                                if err != nil {
×
716
                                        return fmt.Errorf("unable to fetch "+
×
717
                                                "other node(%x): %w",
×
718
                                                otherNodePub, err)
×
719
                                }
×
720

721
                                return cb(
×
722
                                        info.ChannelPoint, outPolicy != nil,
×
723
                                        otherNode,
×
724
                                )
×
725
                        },
726
                )
727
        }, sqldb.NoOpReset)
728
}
729

730
// ForEachNode iterates through all the stored vertices/nodes in the graph,
731
// executing the passed callback with each node encountered. If the callback
732
// returns an error, then the transaction is aborted and the iteration stops
733
// early. Any operations performed on the NodeTx passed to the call-back are
734
// executed under the same read transaction and so, methods on the NodeTx object
735
// _MUST_ only be called from within the call-back.
736
//
737
// NOTE: part of the V1Store interface.
738
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
739
        var (
×
740
                ctx          = context.TODO()
×
741
                lastID int64 = 0
×
742
        )
×
743

×
744
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
745
                node, err := buildNode(ctx, db, &dbNode)
×
746
                if err != nil {
×
747
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
748
                                dbNode.ID, err)
×
749
                }
×
750

751
                err = cb(
×
752
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
753
                )
×
754
                if err != nil {
×
755
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
756
                                dbNode.ID, err)
×
757
                }
×
758

759
                return nil
×
760
        }
761

762
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
763
                for {
×
764
                        nodes, err := db.ListNodesPaginated(
×
765
                                ctx, sqlc.ListNodesPaginatedParams{
×
766
                                        Version: int16(ProtocolV1),
×
767
                                        ID:      lastID,
×
768
                                        Limit:   pageSize,
×
769
                                },
×
770
                        )
×
771
                        if err != nil {
×
772
                                return fmt.Errorf("unable to fetch nodes: %w",
×
773
                                        err)
×
774
                        }
×
775

776
                        if len(nodes) == 0 {
×
777
                                break
×
778
                        }
779

780
                        for _, dbNode := range nodes {
×
781
                                err = handleNode(db, dbNode)
×
782
                                if err != nil {
×
783
                                        return err
×
784
                                }
×
785

786
                                lastID = dbNode.ID
×
787
                        }
788
                }
789

790
                return nil
×
791
        }, sqldb.NoOpReset)
792
}
793

794
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
795
// SQLStore and a SQL transaction.
796
type sqlGraphNodeTx struct {
797
        db    SQLQueries
798
        id    int64
799
        node  *models.LightningNode
800
        chain chainhash.Hash
801
}
802

803
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
804
// interface.
805
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
806

807
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
808
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
809

×
810
        return &sqlGraphNodeTx{
×
811
                db:    db,
×
812
                chain: chain,
×
813
                id:    id,
×
814
                node:  node,
×
815
        }
×
816
}
×
817

818
// Node returns the raw information of the node.
819
//
820
// NOTE: This is a part of the NodeRTx interface.
821
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
822
        return s.node
×
823
}
×
824

825
// ForEachChannel can be used to iterate over the node's channels under the same
826
// transaction used to fetch the node.
827
//
828
// NOTE: This is a part of the NodeRTx interface.
829
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
830
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
831

×
832
        ctx := context.TODO()
×
833

×
834
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
835
}
×
836

837
// FetchNode fetches the node with the given pub key under the same transaction
838
// used to fetch the current node. The returned node is also a NodeRTx and any
839
// operations on that NodeRTx will also be done under the same transaction.
840
//
841
// NOTE: This is a part of the NodeRTx interface.
842
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
843
        ctx := context.TODO()
×
844

×
845
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
846
        if err != nil {
×
847
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
848
                        nodePub, err)
×
849
        }
×
850

851
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
852
}
853

854
// ForEachNodeDirectedChannel iterates through all channels of a given node,
855
// executing the passed callback on the directed edge representing the channel
856
// and its incoming policy. If the callback returns an error, then the iteration
857
// is halted with the error propagated back up to the caller.
858
//
859
// Unknown policies are passed into the callback as nil values.
860
//
861
// NOTE: this is part of the graphdb.NodeTraverser interface.
862
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
863
        cb func(channel *DirectedChannel) error) error {
×
864

×
865
        var ctx = context.TODO()
×
866

×
867
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
868
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
869
        }, sqldb.NoOpReset)
×
870
}
871

872
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
873
// graph, executing the passed callback with each node encountered. If the
874
// callback returns an error, then the transaction is aborted and the iteration
875
// stops early.
876
//
877
// NOTE: This is a part of the V1Store interface.
878
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
879
        *lnwire.FeatureVector) error) error {
×
880

×
881
        ctx := context.TODO()
×
882

×
883
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
884
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
885
                        nodePub route.Vertex) error {
×
886

×
887
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
888
                        if err != nil {
×
889
                                return fmt.Errorf("unable to fetch node "+
×
890
                                        "features: %w", err)
×
891
                        }
×
892

893
                        return cb(nodePub, features)
×
894
                })
895
        }, sqldb.NoOpReset)
896
        if err != nil {
×
897
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
898
        }
×
899

900
        return nil
×
901
}
902

903
// ForEachNodeChannel iterates through all channels of the given node,
904
// executing the passed callback with an edge info structure and the policies
905
// of each end of the channel. The first edge policy is the outgoing edge *to*
906
// the connecting node, while the second is the incoming edge *from* the
907
// connecting node. If the callback returns an error, then the iteration is
908
// halted with the error propagated back up to the caller.
909
//
910
// Unknown policies are passed into the callback as nil values.
911
//
912
// NOTE: part of the V1Store interface.
913
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
914
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
915
                *models.ChannelEdgePolicy) error) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                dbNode, err := db.GetNodeByPubKey(
×
921
                        ctx, sqlc.GetNodeByPubKeyParams{
×
922
                                Version: int16(ProtocolV1),
×
923
                                PubKey:  nodePub[:],
×
924
                        },
×
925
                )
×
926
                if errors.Is(err, sql.ErrNoRows) {
×
927
                        return nil
×
928
                } else if err != nil {
×
929
                        return fmt.Errorf("unable to fetch node: %w", err)
×
930
                }
×
931

932
                return forEachNodeChannel(
×
933
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
934
                )
×
935
        }, sqldb.NoOpReset)
936
}
937

938
// ChanUpdatesInHorizon returns all the known channel edges which have at least
939
// one edge that has an update timestamp within the specified horizon.
940
//
941
// NOTE: This is part of the V1Store interface.
942
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
NEW
943
        endTime time.Time) ([]ChannelEdge, error) {
×
NEW
944

×
NEW
945
        s.cacheMu.Lock()
×
NEW
946
        defer s.cacheMu.Unlock()
×
NEW
947

×
NEW
948
        var (
×
NEW
949
                ctx = context.TODO()
×
NEW
950
                // To ensure we don't return duplicate ChannelEdges, we'll use an
×
NEW
951
                // additional map to keep track of the edges already seen to prevent
×
NEW
952
                // re-adding it.
×
NEW
953
                edgesSeen    = make(map[uint64]struct{})
×
NEW
954
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
955
                edges        []ChannelEdge
×
NEW
956
                hits         int
×
NEW
957
        )
×
NEW
958
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
959
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
NEW
960
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
961
                                Version: int16(ProtocolV1),
×
NEW
962
                                StartTime: sql.NullInt64{
×
NEW
963
                                        Valid: true,
×
NEW
964
                                        Int64: startTime.Unix(),
×
NEW
965
                                },
×
NEW
966
                                EndTime: sql.NullInt64{
×
NEW
967
                                        Valid: true,
×
NEW
968
                                        Int64: endTime.Unix(),
×
NEW
969
                                },
×
NEW
970
                        },
×
NEW
971
                )
×
NEW
972
                if err != nil {
×
NEW
973
                        return err
×
NEW
974
                }
×
975

NEW
976
                for _, row := range rows {
×
NEW
977
                        // If we've already retrieved the info and policies for
×
NEW
978
                        // this edge, then we can skip it as we don't need to do
×
NEW
979
                        // so again.
×
NEW
980
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
NEW
981
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
NEW
982
                                continue
×
983
                        }
984

NEW
985
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
NEW
986
                                hits++
×
NEW
987
                                edgesSeen[chanIDInt] = struct{}{}
×
NEW
988
                                edges = append(edges, channel)
×
NEW
989

×
NEW
990
                                continue
×
991
                        }
992

NEW
993
                        node1, node2, err := buildNodes(
×
NEW
994
                                ctx, db, row.Node, row.Node_2,
×
NEW
995
                        )
×
NEW
996
                        if err != nil {
×
NEW
997
                                return err
×
NEW
998
                        }
×
999

NEW
1000
                        channel, err := getAndBuildEdgeInfo(
×
NEW
1001
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
NEW
1002
                                row.Channel, node1.PubKeyBytes,
×
NEW
1003
                                node2.PubKeyBytes,
×
NEW
1004
                        )
×
NEW
1005
                        if err != nil {
×
NEW
1006
                                return fmt.Errorf("unable to build channel "+
×
NEW
1007
                                        "info: %w", err)
×
NEW
1008
                        }
×
1009

NEW
1010
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1011
                        if err != nil {
×
NEW
1012
                                return fmt.Errorf("unable to extract channel "+
×
NEW
1013
                                        "policies: %w", err)
×
NEW
1014
                        }
×
1015

NEW
1016
                        p1, p2, err := getAndBuildChanPolicies(
×
NEW
1017
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
NEW
1018
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
1019
                        )
×
NEW
1020
                        if err != nil {
×
NEW
1021
                                return fmt.Errorf("unable to build channel "+
×
NEW
1022
                                        "policies: %w", err)
×
NEW
1023
                        }
×
1024

NEW
1025
                        edgesSeen[chanIDInt] = struct{}{}
×
NEW
1026
                        chanEdge := ChannelEdge{
×
NEW
1027
                                Info:    channel,
×
NEW
1028
                                Policy1: p1,
×
NEW
1029
                                Policy2: p2,
×
NEW
1030
                                Node1:   node1,
×
NEW
1031
                                Node2:   node2,
×
NEW
1032
                        }
×
NEW
1033
                        edges = append(edges, chanEdge)
×
NEW
1034
                        edgesToCache[chanIDInt] = chanEdge
×
1035
                }
1036

NEW
1037
                return nil
×
NEW
1038
        }, func() {
×
NEW
1039
                edgesSeen = make(map[uint64]struct{})
×
NEW
1040
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
1041
                edges = nil
×
NEW
1042
        })
×
NEW
1043
        if err != nil {
×
NEW
1044
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
1045
        }
×
1046

1047
        // Insert any edges loaded from disk into the cache.
NEW
1048
        for chanid, channel := range edgesToCache {
×
NEW
1049
                s.chanCache.insert(chanid, channel)
×
NEW
1050
        }
×
1051

NEW
1052
        log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
×
NEW
1053
                float64(hits)/float64(len(edges)), hits, len(edges))
×
NEW
1054

×
NEW
1055
        return edges, nil
×
1056
}
1057

1058
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1059
// data to the call-back.
1060
//
1061
// NOTE: The callback contents MUST not be modified.
1062
//
1063
// NOTE: part of the V1Store interface.
1064
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
NEW
1065
        chans map[uint64]*DirectedChannel) error) error {
×
NEW
1066

×
NEW
1067
        var ctx = context.TODO()
×
NEW
1068

×
NEW
1069
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1070
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
NEW
1071
                        nodePub route.Vertex) error {
×
NEW
1072

×
NEW
1073
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
1074
                        if err != nil {
×
NEW
1075
                                return fmt.Errorf("unable to fetch "+
×
NEW
1076
                                        "node(id=%d) features: %w", nodeID, err)
×
NEW
1077
                        }
×
1078

NEW
1079
                        toNodeCallback := func() route.Vertex {
×
NEW
1080
                                return nodePub
×
NEW
1081
                        }
×
1082

NEW
1083
                        rows, err := db.ListChannelsByNodeID(
×
NEW
1084
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
1085
                                        Version: int16(ProtocolV1),
×
NEW
1086
                                        NodeID1: nodeID,
×
NEW
1087
                                },
×
NEW
1088
                        )
×
NEW
1089
                        if err != nil {
×
NEW
1090
                                return fmt.Errorf("unable to fetch channels "+
×
NEW
1091
                                        "of node(id=%d): %w", nodeID, err)
×
NEW
1092
                        }
×
1093

NEW
1094
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
NEW
1095
                        for _, row := range rows {
×
NEW
1096
                                node1, node2, err := buildNodeVertices(
×
NEW
1097
                                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1098
                                )
×
NEW
1099
                                if err != nil {
×
NEW
1100
                                        return err
×
NEW
1101
                                }
×
1102

NEW
1103
                                e, err := getAndBuildEdgeInfo(
×
NEW
1104
                                        ctx, db, s.cfg.ChainHash,
×
NEW
1105
                                        row.Channel.ID, row.Channel, node1,
×
NEW
1106
                                        node2,
×
NEW
1107
                                )
×
NEW
1108
                                if err != nil {
×
NEW
1109
                                        return fmt.Errorf("unable to build "+
×
NEW
1110
                                                "channel info: %w", err)
×
NEW
1111
                                }
×
1112

NEW
1113
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
NEW
1114
                                        row,
×
NEW
1115
                                )
×
NEW
1116
                                if err != nil {
×
NEW
1117
                                        return fmt.Errorf("unable to "+
×
NEW
1118
                                                "extract channel "+
×
NEW
1119
                                                "policies: %w", err)
×
NEW
1120
                                }
×
1121

NEW
1122
                                p1, p2, err := getAndBuildChanPolicies(
×
NEW
1123
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
NEW
1124
                                        node1, node2,
×
NEW
1125
                                )
×
NEW
1126
                                if err != nil {
×
NEW
1127
                                        return fmt.Errorf("unable to "+
×
NEW
1128
                                                "build channel policies: %w",
×
NEW
1129
                                                err)
×
NEW
1130
                                }
×
1131

1132
                                // Determine the outgoing and incoming policy
1133
                                // for this channel and node combo.
NEW
1134
                                outPolicy, inPolicy := p1, p2
×
NEW
1135
                                if p1 != nil && p1.ToNode == nodePub {
×
NEW
1136
                                        outPolicy, inPolicy = p2, p1
×
NEW
1137
                                } else if p2 != nil && p2.ToNode != nodePub {
×
NEW
1138
                                        outPolicy, inPolicy = p2, p1
×
NEW
1139
                                }
×
1140

NEW
1141
                                var cachedInPolicy *models.CachedEdgePolicy
×
NEW
1142
                                if inPolicy != nil {
×
NEW
1143
                                        cachedInPolicy = models.NewCachedPolicy(
×
NEW
1144
                                                p2,
×
NEW
1145
                                        )
×
NEW
1146
                                        cachedInPolicy.ToNodePubKey =
×
NEW
1147
                                                toNodeCallback
×
NEW
1148
                                        cachedInPolicy.ToNodeFeatures =
×
NEW
1149
                                                features
×
NEW
1150
                                }
×
1151

NEW
1152
                                var inboundFee lnwire.Fee
×
NEW
1153
                                outPolicy.InboundFee.WhenSome(
×
NEW
1154
                                        func(fee lnwire.Fee) {
×
NEW
1155
                                                inboundFee = fee
×
NEW
1156
                                        },
×
1157
                                )
1158

NEW
1159
                                directedChannel := &DirectedChannel{
×
NEW
1160
                                        ChannelID: e.ChannelID,
×
NEW
1161
                                        IsNode1: nodePub ==
×
NEW
1162
                                                e.NodeKey1Bytes,
×
NEW
1163
                                        OtherNode:    e.NodeKey2Bytes,
×
NEW
1164
                                        Capacity:     e.Capacity,
×
NEW
1165
                                        OutPolicySet: p1 != nil,
×
NEW
1166
                                        InPolicy:     cachedInPolicy,
×
NEW
1167
                                        InboundFee:   inboundFee,
×
NEW
1168
                                }
×
NEW
1169

×
NEW
1170
                                if nodePub == e.NodeKey2Bytes {
×
NEW
1171
                                        directedChannel.OtherNode =
×
NEW
1172
                                                e.NodeKey1Bytes
×
NEW
1173
                                }
×
1174

NEW
1175
                                channels[e.ChannelID] = directedChannel
×
1176
                        }
1177

NEW
1178
                        return cb(nodePub, channels)
×
1179
                })
1180
        }, sqldb.NoOpReset)
1181
}
1182

1183
// ForEachChannel iterates through all the channel edges stored within the
1184
// graph and invokes the passed callback for each edge. The callback takes two
1185
// edges as since this is a directed graph, both the in/out edges are visited.
1186
// If the callback returns an error, then the transaction is aborted and the
1187
// iteration stops early.
1188
//
1189
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1190
// for that particular channel edge routing policy will be passed into the
1191
// callback.
1192
//
1193
// NOTE: part of the V1Store interface.
1194
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
1195
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
1196

×
NEW
1197
        var (
×
NEW
1198
                ctx          = context.TODO()
×
NEW
1199
                lastID int64 = 0
×
NEW
1200
        )
×
NEW
1201

×
NEW
1202
        handleChannel := func(db SQLQueries,
×
NEW
1203
                row sqlc.ListChannelsPaginatedRow) error {
×
NEW
1204

×
NEW
1205
                node1, node2, err := buildNodeVertices(
×
NEW
1206
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1207
                )
×
NEW
1208
                if err != nil {
×
NEW
1209
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
1210
                                err)
×
NEW
1211
                }
×
1212

NEW
1213
                edge, err := getAndBuildEdgeInfo(
×
NEW
1214
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
NEW
1215
                        node1, node2,
×
NEW
1216
                )
×
NEW
1217
                if err != nil {
×
NEW
1218
                        return fmt.Errorf("unable to build channel info: %w",
×
NEW
1219
                                err)
×
NEW
1220
                }
×
1221

NEW
1222
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1223
                if err != nil {
×
NEW
1224
                        return fmt.Errorf("unable to extract channel "+
×
NEW
1225
                                "policies: %w", err)
×
NEW
1226
                }
×
1227

NEW
1228
                p1, p2, err := getAndBuildChanPolicies(
×
NEW
1229
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
1230
                )
×
NEW
1231
                if err != nil {
×
NEW
1232
                        return fmt.Errorf("unable to build channel "+
×
NEW
1233
                                "policies: %w", err)
×
NEW
1234
                }
×
1235

NEW
1236
                err = cb(edge, p1, p2)
×
NEW
1237
                if err != nil {
×
NEW
1238
                        return fmt.Errorf("callback failed for channel "+
×
NEW
1239
                                "id=%d: %w", edge.ChannelID, err)
×
NEW
1240
                }
×
1241

NEW
1242
                return nil
×
1243
        }
1244

NEW
1245
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1246
                for {
×
NEW
1247
                        rows, err := db.ListChannelsPaginated(
×
NEW
1248
                                ctx, sqlc.ListChannelsPaginatedParams{
×
NEW
1249
                                        Version: int16(ProtocolV1),
×
NEW
1250
                                        ID:      lastID,
×
NEW
1251
                                        Limit:   pageSize,
×
NEW
1252
                                },
×
NEW
1253
                        )
×
NEW
1254
                        if err != nil {
×
NEW
1255
                                return err
×
NEW
1256
                        }
×
1257

NEW
1258
                        if len(rows) == 0 {
×
NEW
1259
                                break
×
1260
                        }
1261

NEW
1262
                        for _, row := range rows {
×
NEW
1263
                                err := handleChannel(db, row)
×
NEW
1264
                                if err != nil {
×
NEW
1265
                                        return err
×
NEW
1266
                                }
×
1267

NEW
1268
                                lastID = row.Channel.ID
×
1269
                        }
1270
                }
1271

NEW
1272
                return nil
×
1273
        }, sqldb.NoOpReset)
1274
}
1275

1276
// FilterChannelRange returns the channel ID's of all known channels which were
1277
// mined in a block height within the passed range. The channel IDs are grouped
1278
// by their common block height. This method can be used to quickly share with a
1279
// peer the set of channels we know of within a particular range to catch them
1280
// up after a period of time offline. If withTimestamps is true then the
1281
// timestamp info of the latest received channel update messages of the channel
1282
// will be included in the response.
1283
//
1284
// NOTE: This is part of the V1Store interface.
1285
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
NEW
1286
        withTimestamps bool) ([]BlockChannelRange, error) {
×
NEW
1287

×
NEW
1288
        var (
×
NEW
1289
                ctx       = context.TODO()
×
NEW
1290
                startSCID = &lnwire.ShortChannelID{
×
NEW
1291
                        BlockHeight: startHeight,
×
NEW
1292
                }
×
NEW
1293
                endSCID = lnwire.ShortChannelID{
×
NEW
1294
                        BlockHeight: endHeight,
×
NEW
1295
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
NEW
1296
                        TxPosition:  math.MaxUint16,
×
NEW
1297
                }
×
NEW
1298
        )
×
NEW
1299

×
NEW
1300
        var chanIDStart [8]byte
×
NEW
1301
        byteOrder.PutUint64(chanIDStart[:], startSCID.ToUint64())
×
NEW
1302
        var chanIDEnd [8]byte
×
NEW
1303
        byteOrder.PutUint64(chanIDEnd[:], endSCID.ToUint64())
×
NEW
1304

×
NEW
1305
        // 1) get all channels where channelID is between start and end chan ID.
×
NEW
1306
        // 2) skip if not public (ie, no channel_proof)
×
NEW
1307
        // 3) collect that channel.
×
NEW
1308
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
NEW
1309
        //    and add those timestamps to the collected channel.
×
NEW
1310
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
NEW
1311
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1312
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
NEW
1313
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1314
                                StartScid: chanIDStart[:],
×
NEW
1315
                                EndScid:   chanIDEnd[:],
×
NEW
1316
                        },
×
NEW
1317
                )
×
NEW
1318
                if err != nil {
×
NEW
1319
                        return fmt.Errorf("unable to fetch channel range: %w",
×
NEW
1320
                                err)
×
NEW
1321
                }
×
1322

NEW
1323
                for _, dbChan := range dbChans {
×
NEW
1324
                        cid := lnwire.NewShortChanIDFromInt(
×
NEW
1325
                                byteOrder.Uint64(dbChan.Scid),
×
NEW
1326
                        )
×
NEW
1327
                        chanInfo := NewChannelUpdateInfo(
×
NEW
1328
                                cid, time.Time{}, time.Time{},
×
NEW
1329
                        )
×
NEW
1330

×
NEW
1331
                        if !withTimestamps {
×
NEW
1332
                                channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1333
                                        channelsPerBlock[cid.BlockHeight],
×
NEW
1334
                                        chanInfo,
×
NEW
1335
                                )
×
NEW
1336

×
NEW
1337
                                continue
×
1338
                        }
1339

1340
                        //nolint:ll
NEW
1341
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1342
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1343
                                        Version:   int16(ProtocolV1),
×
NEW
1344
                                        ChannelID: dbChan.ID,
×
NEW
1345
                                        NodeID:    dbChan.NodeID1,
×
NEW
1346
                                },
×
NEW
1347
                        )
×
NEW
1348
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1349
                                return fmt.Errorf("unable to fetch node1 "+
×
NEW
1350
                                        "policy: %w", err)
×
NEW
1351
                        } else if err == nil {
×
NEW
1352
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
NEW
1353
                                        node1Policy.LastUpdate.Int64, 0,
×
NEW
1354
                                )
×
NEW
1355
                        }
×
1356

1357
                        //nolint:ll
NEW
1358
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1359
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1360
                                        Version:   int16(ProtocolV1),
×
NEW
1361
                                        ChannelID: dbChan.ID,
×
NEW
1362
                                        NodeID:    dbChan.NodeID2,
×
NEW
1363
                                },
×
NEW
1364
                        )
×
NEW
1365
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1366
                                return fmt.Errorf("unable to fetch node2 "+
×
NEW
1367
                                        "policy: %w", err)
×
NEW
1368
                        } else if err == nil {
×
NEW
1369
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
NEW
1370
                                        node2Policy.LastUpdate.Int64, 0,
×
NEW
1371
                                )
×
NEW
1372
                        }
×
1373

NEW
1374
                        channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1375
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
NEW
1376
                        )
×
1377
                }
1378

NEW
1379
                return nil
×
NEW
1380
        }, func() {
×
NEW
1381
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
NEW
1382
        })
×
NEW
1383
        if err != nil {
×
NEW
1384
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
NEW
1385
        }
×
1386

NEW
1387
        if len(channelsPerBlock) == 0 {
×
NEW
1388
                return nil, nil
×
NEW
1389
        }
×
1390

1391
        // Return the channel ranges in ascending block height order.
NEW
1392
        blocks := make([]uint32, 0, len(channelsPerBlock))
×
NEW
1393
        for block := range channelsPerBlock {
×
NEW
1394
                blocks = append(blocks, block)
×
NEW
1395
        }
×
NEW
1396
        sort.Slice(blocks, func(i, j int) bool {
×
NEW
1397
                return blocks[i] < blocks[j]
×
NEW
1398
        })
×
1399

NEW
1400
        channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
×
NEW
1401
        for _, block := range blocks {
×
NEW
1402
                channelRanges = append(channelRanges, BlockChannelRange{
×
NEW
1403
                        Height:   block,
×
NEW
1404
                        Channels: channelsPerBlock[block],
×
NEW
1405
                })
×
NEW
1406
        }
×
1407

NEW
1408
        return channelRanges, nil
×
1409
}
1410

1411
// forEachNodeDirectedChannel iterates through all channels of a given
1412
// node, executing the passed callback on the directed edge representing the
1413
// channel and its incoming policy. If the node is not found, no error is
1414
// returned.
1415
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
1416
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
1417

×
1418
        toNodeCallback := func() route.Vertex {
×
1419
                return nodePub
×
1420
        }
×
1421

1422
        dbID, err := db.GetNodeIDByPubKey(
×
1423
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
1424
                        Version: int16(ProtocolV1),
×
1425
                        PubKey:  nodePub[:],
×
1426
                },
×
1427
        )
×
1428
        if errors.Is(err, sql.ErrNoRows) {
×
1429
                return nil
×
1430
        } else if err != nil {
×
1431
                return fmt.Errorf("unable to fetch node: %w", err)
×
1432
        }
×
1433

1434
        rows, err := db.ListChannelsByNodeID(
×
1435
                ctx, sqlc.ListChannelsByNodeIDParams{
×
1436
                        Version: int16(ProtocolV1),
×
1437
                        NodeID1: dbID,
×
1438
                },
×
1439
        )
×
1440
        if err != nil {
×
1441
                return fmt.Errorf("unable to fetch channels: %w", err)
×
1442
        }
×
1443

1444
        // Exit early if there are no channels for this node so we don't
1445
        // do the unnecessary feature fetching.
1446
        if len(rows) == 0 {
×
1447
                return nil
×
1448
        }
×
1449

1450
        features, err := getNodeFeatures(ctx, db, dbID)
×
1451
        if err != nil {
×
1452
                return fmt.Errorf("unable to fetch node features: %w", err)
×
1453
        }
×
1454

1455
        for _, row := range rows {
×
1456
                node1, node2, err := buildNodeVertices(
×
1457
                        row.Node1Pubkey, row.Node2Pubkey,
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to build node vertices: %w",
×
1461
                                err)
×
1462
                }
×
1463

NEW
1464
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1465

×
1466
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1467
                if err != nil {
×
1468
                        return err
×
1469
                }
×
1470

1471
                var p1, p2 *models.CachedEdgePolicy
×
1472
                if dbPol1 != nil {
×
1473
                        policy1, err := buildChanPolicy(
×
1474
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1475
                        )
×
1476
                        if err != nil {
×
1477
                                return err
×
1478
                        }
×
1479

1480
                        p1 = models.NewCachedPolicy(policy1)
×
1481
                }
1482
                if dbPol2 != nil {
×
1483
                        policy2, err := buildChanPolicy(
×
1484
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1485
                        )
×
1486
                        if err != nil {
×
1487
                                return err
×
1488
                        }
×
1489

1490
                        p2 = models.NewCachedPolicy(policy2)
×
1491
                }
1492

1493
                // Determine the outgoing and incoming policy for this
1494
                // channel and node combo.
1495
                outPolicy, inPolicy := p1, p2
×
1496
                if p1 != nil && node2 == nodePub {
×
1497
                        outPolicy, inPolicy = p2, p1
×
1498
                } else if p2 != nil && node1 != nodePub {
×
1499
                        outPolicy, inPolicy = p2, p1
×
1500
                }
×
1501

1502
                var cachedInPolicy *models.CachedEdgePolicy
×
1503
                if inPolicy != nil {
×
1504
                        cachedInPolicy = inPolicy
×
1505
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
1506
                        cachedInPolicy.ToNodeFeatures = features
×
1507
                }
×
1508

1509
                directedChannel := &DirectedChannel{
×
1510
                        ChannelID:    edge.ChannelID,
×
1511
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
1512
                        OtherNode:    edge.NodeKey2Bytes,
×
1513
                        Capacity:     edge.Capacity,
×
1514
                        OutPolicySet: outPolicy != nil,
×
1515
                        InPolicy:     cachedInPolicy,
×
1516
                }
×
1517
                if outPolicy != nil {
×
1518
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
1519
                                directedChannel.InboundFee = fee
×
1520
                        })
×
1521
                }
1522

1523
                if nodePub == edge.NodeKey2Bytes {
×
1524
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
1525
                }
×
1526

1527
                if err := cb(directedChannel); err != nil {
×
1528
                        return err
×
1529
                }
×
1530
        }
1531

1532
        return nil
×
1533
}
1534

1535
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
1536
// and executes the provided callback for each node.
1537
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
1538
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
1539

×
1540
        var lastID int64
×
1541

×
1542
        for {
×
1543
                nodes, err := db.ListNodeIDsAndPubKeys(
×
1544
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1545
                                Version: int16(ProtocolV1),
×
1546
                                ID:      lastID,
×
1547
                                Limit:   pageSize,
×
1548
                        },
×
1549
                )
×
1550
                if err != nil {
×
1551
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
1552
                }
×
1553

1554
                if len(nodes) == 0 {
×
1555
                        break
×
1556
                }
1557

1558
                for _, node := range nodes {
×
1559
                        var pub route.Vertex
×
1560
                        copy(pub[:], node.PubKey)
×
1561

×
1562
                        if err := cb(node.ID, pub); err != nil {
×
1563
                                return fmt.Errorf("forEachNodeCacheable "+
×
1564
                                        "callback failed for node(id=%d): %w",
×
1565
                                        node.ID, err)
×
1566
                        }
×
1567

1568
                        lastID = node.ID
×
1569
                }
1570
        }
1571

1572
        return nil
×
1573
}
1574

1575
// forEachNodeChannel iterates through all channels of a node, executing
1576
// the passed callback on each. The call-back is provided with the channel's
1577
// edge information, the outgoing policy and the incoming policy for the
1578
// channel and node combo.
1579
func forEachNodeChannel(ctx context.Context, db SQLQueries,
1580
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
1581
                *models.ChannelEdgePolicy,
1582
                *models.ChannelEdgePolicy) error) error {
×
1583

×
1584
        // Get all the V1 channels for this node.Add commentMore actions
×
1585
        rows, err := db.ListChannelsByNodeID(
×
1586
                ctx, sqlc.ListChannelsByNodeIDParams{
×
1587
                        Version: int16(ProtocolV1),
×
1588
                        NodeID1: id,
×
1589
                },
×
1590
        )
×
1591
        if err != nil {
×
1592
                return fmt.Errorf("unable to fetch channels: %w", err)
×
1593
        }
×
1594

1595
        // Call the call-back for each channel and its known policies.
1596
        for _, row := range rows {
×
1597
                node1, node2, err := buildNodeVertices(
×
1598
                        row.Node1Pubkey, row.Node2Pubkey,
×
1599
                )
×
1600
                if err != nil {
×
1601
                        return fmt.Errorf("unable to build node vertices: %w",
×
1602
                                err)
×
1603
                }
×
1604

1605
                edge, err := getAndBuildEdgeInfo(
×
1606
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
1607
                        node2,
×
1608
                )
×
1609
                if err != nil {
×
1610
                        return fmt.Errorf("unable to build channel info: %w",
×
1611
                                err)
×
1612
                }
×
1613

1614
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1615
                if err != nil {
×
1616
                        return fmt.Errorf("unable to extract channel "+
×
1617
                                "policies: %w", err)
×
1618
                }
×
1619

1620
                p1, p2, err := getAndBuildChanPolicies(
×
1621
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1622
                )
×
1623
                if err != nil {
×
1624
                        return fmt.Errorf("unable to build channel "+
×
1625
                                "policies: %w", err)
×
1626
                }
×
1627

1628
                // Determine the outgoing and incoming policy for this
1629
                // channel and node combo.
1630
                p1ToNode := row.Channel.NodeID2
×
1631
                p2ToNode := row.Channel.NodeID1
×
1632
                outPolicy, inPolicy := p1, p2
×
1633
                if (p1 != nil && p1ToNode == id) ||
×
1634
                        (p2 != nil && p2ToNode != id) {
×
1635

×
1636
                        outPolicy, inPolicy = p2, p1
×
1637
                }
×
1638

1639
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
1640
                        return err
×
1641
                }
×
1642
        }
1643

1644
        return nil
×
1645
}
1646

1647
// updateChanEdgePolicy upserts the channel policy info we have stored for
1648
// a channel we already know of.
1649
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
1650
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
1651
        error) {
×
1652

×
1653
        var (
×
1654
                node1Pub, node2Pub route.Vertex
×
1655
                isNode1            bool
×
1656
                chanIDB            [8]byte
×
1657
        )
×
1658
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
1659

×
1660
        // Check that this edge policy refers to a channel that we already
×
1661
        // know of. We do this explicitly so that we can return the appropriate
×
1662
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
1663
        // abort the transaction which would abort the entire batch.
×
1664
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
1665
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
1666
                        Scid:    chanIDB[:],
×
1667
                        Version: int16(ProtocolV1),
×
1668
                },
×
1669
        )
×
1670
        if errors.Is(err, sql.ErrNoRows) {
×
1671
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
1672
        } else if err != nil {
×
1673
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1674
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
1675
        }
×
1676

1677
        copy(node1Pub[:], dbChan.Node1PubKey)
×
1678
        copy(node2Pub[:], dbChan.Node2PubKey)
×
1679

×
1680
        // Figure out which node this edge is from.
×
1681
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
1682
        nodeID := dbChan.NodeID1
×
1683
        if !isNode1 {
×
1684
                nodeID = dbChan.NodeID2
×
1685
        }
×
1686

1687
        var (
×
1688
                inboundBase sql.NullInt64
×
1689
                inboundRate sql.NullInt64
×
1690
        )
×
1691
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
1692
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
1693
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
1694
        })
×
1695

1696
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
1697
                Version:     int16(ProtocolV1),
×
1698
                ChannelID:   dbChan.ID,
×
1699
                NodeID:      nodeID,
×
1700
                Timelock:    int32(edge.TimeLockDelta),
×
1701
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
1702
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
1703
                MinHtlcMsat: int64(edge.MinHTLC),
×
1704
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
1705
                Disabled: sql.NullBool{
×
1706
                        Valid: true,
×
1707
                        Bool:  edge.IsDisabled(),
×
1708
                },
×
1709
                MaxHtlcMsat: sql.NullInt64{
×
1710
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
1711
                        Int64: int64(edge.MaxHTLC),
×
1712
                },
×
1713
                InboundBaseFeeMsat:      inboundBase,
×
1714
                InboundFeeRateMilliMsat: inboundRate,
×
1715
                Signature:               edge.SigBytes,
×
1716
        })
×
1717
        if err != nil {
×
1718
                return node1Pub, node2Pub, isNode1,
×
1719
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
1720
        }
×
1721

1722
        // Convert the flat extra opaque data into a map of TLV types to
1723
        // values.
1724
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
1725
        if err != nil {
×
1726
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1727
                        "marshal extra opaque data: %w", err)
×
1728
        }
×
1729

1730
        // Update the channel policy's extra signed fields.
1731
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
1732
        if err != nil {
×
1733
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
1734
                        "policy extra TLVs: %w", err)
×
1735
        }
×
1736

1737
        return node1Pub, node2Pub, isNode1, nil
×
1738
}
1739

1740
// getNodeByPubKey attempts to look up a target node by its public key.
1741
func getNodeByPubKey(ctx context.Context, db SQLQueries,
1742
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
1743

×
1744
        dbNode, err := db.GetNodeByPubKey(
×
1745
                ctx, sqlc.GetNodeByPubKeyParams{
×
1746
                        Version: int16(ProtocolV1),
×
1747
                        PubKey:  pubKey[:],
×
1748
                },
×
1749
        )
×
1750
        if errors.Is(err, sql.ErrNoRows) {
×
1751
                return 0, nil, ErrGraphNodeNotFound
×
1752
        } else if err != nil {
×
1753
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
1754
        }
×
1755

1756
        node, err := buildNode(ctx, db, &dbNode)
×
1757
        if err != nil {
×
1758
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
1759
        }
×
1760

1761
        return dbNode.ID, node, nil
×
1762
}
1763

1764
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
1765
// provided database channel row and the public keys of the two nodes
1766
// involved in the channel.
1767
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
NEW
1768
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
1769

×
1770
        return &models.CachedEdgeInfo{
×
1771
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
1772
                NodeKey1Bytes: node1Pub,
×
1773
                NodeKey2Bytes: node2Pub,
×
1774
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
1775
        }
×
1776
}
×
1777

1778
// buildNode constructs a LightningNode instance from the given database node
1779
// record. The node's features, addresses and extra signed fields are also
1780
// fetched from the database and set on the node.
1781
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
1782
        *models.LightningNode, error) {
×
1783

×
1784
        if dbNode.Version != int16(ProtocolV1) {
×
1785
                return nil, fmt.Errorf("unsupported node version: %d",
×
1786
                        dbNode.Version)
×
1787
        }
×
1788

1789
        var pub [33]byte
×
1790
        copy(pub[:], dbNode.PubKey)
×
1791

×
1792
        node := &models.LightningNode{
×
1793
                PubKeyBytes: pub,
×
1794
                Features:    lnwire.EmptyFeatureVector(),
×
1795
                LastUpdate:  time.Unix(0, 0),
×
1796
        }
×
1797

×
1798
        if len(dbNode.Signature) == 0 {
×
1799
                return node, nil
×
1800
        }
×
1801

1802
        node.HaveNodeAnnouncement = true
×
1803
        node.AuthSigBytes = dbNode.Signature
×
1804
        node.Alias = dbNode.Alias.String
×
1805
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
1806

×
1807
        var err error
×
1808
        if dbNode.Color.Valid {
×
1809
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
1810
                if err != nil {
×
1811
                        return nil, fmt.Errorf("unable to decode color: %w",
×
1812
                                err)
×
1813
                }
×
1814
        }
1815

1816
        // Fetch the node's features.
1817
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
1818
        if err != nil {
×
1819
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1820
                        "features: %w", dbNode.ID, err)
×
1821
        }
×
1822

1823
        // Fetch the node's addresses.
1824
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
1825
        if err != nil {
×
1826
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1827
                        "addresses: %w", dbNode.ID, err)
×
1828
        }
×
1829

1830
        // Fetch the node's extra signed fields.
1831
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
1832
        if err != nil {
×
1833
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1834
                        "extra signed fields: %w", dbNode.ID, err)
×
1835
        }
×
1836

1837
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
1838
        if err != nil {
×
1839
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
1840
                        "fields: %w", err)
×
1841
        }
×
1842

1843
        if len(recs) != 0 {
×
1844
                node.ExtraOpaqueData = recs
×
1845
        }
×
1846

1847
        return node, nil
×
1848
}
1849

1850
// getNodeFeatures fetches the feature bits and constructs the feature vector
1851
// for a node with the given DB ID.
1852
func getNodeFeatures(ctx context.Context, db SQLQueries,
1853
        nodeID int64) (*lnwire.FeatureVector, error) {
×
1854

×
1855
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
1856
        if err != nil {
×
1857
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
1858
                        nodeID, err)
×
1859
        }
×
1860

1861
        features := lnwire.EmptyFeatureVector()
×
1862
        for _, feature := range rows {
×
1863
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
1864
        }
×
1865

1866
        return features, nil
×
1867
}
1868

1869
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
1870
// given DB ID.
1871
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
1872
        nodeID int64) (map[uint64][]byte, error) {
×
1873

×
1874
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
1875
        if err != nil {
×
1876
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
1877
                        "signed fields: %w", nodeID, err)
×
1878
        }
×
1879

1880
        extraFields := make(map[uint64][]byte)
×
1881
        for _, field := range fields {
×
1882
                extraFields[uint64(field.Type)] = field.Value
×
1883
        }
×
1884

1885
        return extraFields, nil
×
1886
}
1887

1888
// upsertNode upserts the node record into the database. If the node already
1889
// exists, then the node's information is updated. If the node doesn't exist,
1890
// then a new node is created. The node's features, addresses and extra TLV
1891
// types are also updated. The node's DB ID is returned.
1892
func upsertNode(ctx context.Context, db SQLQueries,
1893
        node *models.LightningNode) (int64, error) {
×
1894

×
1895
        params := sqlc.UpsertNodeParams{
×
1896
                Version: int16(ProtocolV1),
×
1897
                PubKey:  node.PubKeyBytes[:],
×
1898
        }
×
1899

×
1900
        if node.HaveNodeAnnouncement {
×
1901
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
1902
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
1903
                params.Alias = sqldb.SQLStr(node.Alias)
×
1904
                params.Signature = node.AuthSigBytes
×
1905
        }
×
1906

1907
        nodeID, err := db.UpsertNode(ctx, params)
×
1908
        if err != nil {
×
1909
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
1910
                        err)
×
1911
        }
×
1912

1913
        // We can exit here if we don't have the announcement yet.
1914
        if !node.HaveNodeAnnouncement {
×
1915
                return nodeID, nil
×
1916
        }
×
1917

1918
        // Update the node's features.
1919
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
1920
        if err != nil {
×
1921
                return 0, fmt.Errorf("inserting node features: %w", err)
×
1922
        }
×
1923

1924
        // Update the node's addresses.
1925
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
1926
        if err != nil {
×
1927
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
1928
        }
×
1929

1930
        // Convert the flat extra opaque data into a map of TLV types to
1931
        // values.
1932
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
1933
        if err != nil {
×
1934
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
1935
                        err)
×
1936
        }
×
1937

1938
        // Update the node's extra signed fields.
1939
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
1940
        if err != nil {
×
1941
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
1942
        }
×
1943

1944
        return nodeID, nil
×
1945
}
1946

1947
// upsertNodeFeatures updates the node's features node_features table. This
1948
// includes deleting any feature bits no longer present and inserting any new
1949
// feature bits. If the feature bit does not yet exist in the features table,
1950
// then an entry is created in that table first.
1951
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
1952
        features *lnwire.FeatureVector) error {
×
1953

×
1954
        // Get any existing features for the node.
×
1955
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
1956
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1957
                return err
×
1958
        }
×
1959

1960
        // Copy the nodes latest set of feature bits.
1961
        newFeatures := make(map[int32]struct{})
×
1962
        if features != nil {
×
1963
                for feature := range features.Features() {
×
1964
                        newFeatures[int32(feature)] = struct{}{}
×
1965
                }
×
1966
        }
1967

1968
        // For any current feature that already exists in the DB, remove it from
1969
        // the in-memory map. For any existing feature that does not exist in
1970
        // the in-memory map, delete it from the database.
1971
        for _, feature := range existingFeatures {
×
1972
                // The feature is still present, so there are no updates to be
×
1973
                // made.
×
1974
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
1975
                        delete(newFeatures, feature.FeatureBit)
×
1976
                        continue
×
1977
                }
1978

1979
                // The feature is no longer present, so we remove it from the
1980
                // database.
1981
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
1982
                        NodeID:     nodeID,
×
1983
                        FeatureBit: feature.FeatureBit,
×
1984
                })
×
1985
                if err != nil {
×
1986
                        return fmt.Errorf("unable to delete node(%d) "+
×
1987
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
1988
                                err)
×
1989
                }
×
1990
        }
1991

1992
        // Any remaining entries in newFeatures are new features that need to be
1993
        // added to the database for the first time.
1994
        for feature := range newFeatures {
×
1995
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
1996
                        NodeID:     nodeID,
×
1997
                        FeatureBit: feature,
×
1998
                })
×
1999
                if err != nil {
×
2000
                        return fmt.Errorf("unable to insert node(%d) "+
×
2001
                                "feature(%v): %w", nodeID, feature, err)
×
2002
                }
×
2003
        }
2004

2005
        return nil
×
2006
}
2007

2008
// fetchNodeFeatures fetches the features for a node with the given public key.
2009
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
2010
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
2011

×
2012
        rows, err := queries.GetNodeFeaturesByPubKey(
×
2013
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
2014
                        PubKey:  nodePub[:],
×
2015
                        Version: int16(ProtocolV1),
×
2016
                },
×
2017
        )
×
2018
        if err != nil {
×
2019
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
2020
                        nodePub, err)
×
2021
        }
×
2022

2023
        features := lnwire.EmptyFeatureVector()
×
2024
        for _, bit := range rows {
×
2025
                features.Set(lnwire.FeatureBit(bit))
×
2026
        }
×
2027

2028
        return features, nil
×
2029
}
2030

2031
// dbAddressType enumerates the kinds of address that are stored in the
// node_addresses table. The type recorded with an address determines how that
// address is serialised and deserialised.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an address that is stored in its opaque
	// form.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
2043

2044
// upsertNodeAddresses updates the node's addresses in the database. This
2045
// includes deleting any existing addresses and inserting the new set of
2046
// addresses. The deletion is necessary since the ordering of the addresses may
2047
// change, and we need to ensure that the database reflects the latest set of
2048
// addresses so that at the time of reconstructing the node announcement, the
2049
// order is preserved and the signature over the message remains valid.
2050
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
2051
        addresses []net.Addr) error {
×
2052

×
2053
        // Delete any existing addresses for the node. This is required since
×
2054
        // even if the new set of addresses is the same, the ordering may have
×
2055
        // changed for a given address type.
×
2056
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
2057
        if err != nil {
×
2058
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
2059
                        nodeID, err)
×
2060
        }
×
2061

2062
        // Copy the nodes latest set of addresses.
2063
        newAddresses := map[dbAddressType][]string{
×
2064
                addressTypeIPv4:   {},
×
2065
                addressTypeIPv6:   {},
×
2066
                addressTypeTorV2:  {},
×
2067
                addressTypeTorV3:  {},
×
2068
                addressTypeOpaque: {},
×
2069
        }
×
2070
        addAddr := func(t dbAddressType, addr net.Addr) {
×
2071
                newAddresses[t] = append(newAddresses[t], addr.String())
×
2072
        }
×
2073

2074
        for _, address := range addresses {
×
2075
                switch addr := address.(type) {
×
2076
                case *net.TCPAddr:
×
2077
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
2078
                                addAddr(addressTypeIPv4, addr)
×
2079
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
2080
                                addAddr(addressTypeIPv6, addr)
×
2081
                        } else {
×
2082
                                return fmt.Errorf("unhandled IP address: %v",
×
2083
                                        addr)
×
2084
                        }
×
2085

2086
                case *tor.OnionAddr:
×
2087
                        switch len(addr.OnionService) {
×
2088
                        case tor.V2Len:
×
2089
                                addAddr(addressTypeTorV2, addr)
×
2090
                        case tor.V3Len:
×
2091
                                addAddr(addressTypeTorV3, addr)
×
2092
                        default:
×
2093
                                return fmt.Errorf("invalid length for a tor " +
×
2094
                                        "address")
×
2095
                        }
2096

2097
                case *lnwire.OpaqueAddrs:
×
2098
                        addAddr(addressTypeOpaque, addr)
×
2099

2100
                default:
×
2101
                        return fmt.Errorf("unhandled address type: %T", addr)
×
2102
                }
2103
        }
2104

2105
        // Any remaining entries in newAddresses are new addresses that need to
2106
        // be added to the database for the first time.
2107
        for addrType, addrList := range newAddresses {
×
2108
                for position, addr := range addrList {
×
2109
                        err := db.InsertNodeAddress(
×
2110
                                ctx, sqlc.InsertNodeAddressParams{
×
2111
                                        NodeID:   nodeID,
×
2112
                                        Type:     int16(addrType),
×
2113
                                        Address:  addr,
×
2114
                                        Position: int32(position),
×
2115
                                },
×
2116
                        )
×
2117
                        if err != nil {
×
2118
                                return fmt.Errorf("unable to insert "+
×
2119
                                        "node(%d) address(%v): %w", nodeID,
×
2120
                                        addr, err)
×
2121
                        }
×
2122
                }
2123
        }
2124

2125
        return nil
×
2126
}
2127

2128
// getNodeAddresses fetches the addresses for a node with the given public key.
2129
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
2130
        []net.Addr, error) {
×
2131

×
2132
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
2133
        // are returned in the same order as they were inserted.
×
2134
        rows, err := db.GetNodeAddressesByPubKey(
×
2135
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
2136
                        Version: int16(ProtocolV1),
×
2137
                        PubKey:  nodePub,
×
2138
                },
×
2139
        )
×
2140
        if err != nil {
×
2141
                return false, nil, err
×
2142
        }
×
2143

2144
        // GetNodeAddressesByPubKey uses a left join so there should always be
2145
        // at least one row returned if the node exists even if it has no
2146
        // addresses.
2147
        if len(rows) == 0 {
×
2148
                return false, nil, nil
×
2149
        }
×
2150

2151
        addresses := make([]net.Addr, 0, len(rows))
×
2152
        for _, addr := range rows {
×
2153
                if !(addr.Type.Valid && addr.Address.Valid) {
×
2154
                        continue
×
2155
                }
2156

2157
                address := addr.Address.String
×
2158

×
2159
                switch dbAddressType(addr.Type.Int16) {
×
2160
                case addressTypeIPv4:
×
2161
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
2162
                        if err != nil {
×
2163
                                return false, nil, nil
×
2164
                        }
×
2165
                        tcp.IP = tcp.IP.To4()
×
2166

×
2167
                        addresses = append(addresses, tcp)
×
2168

2169
                case addressTypeIPv6:
×
2170
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
2171
                        if err != nil {
×
2172
                                return false, nil, nil
×
2173
                        }
×
2174
                        addresses = append(addresses, tcp)
×
2175

2176
                case addressTypeTorV3, addressTypeTorV2:
×
2177
                        service, portStr, err := net.SplitHostPort(address)
×
2178
                        if err != nil {
×
2179
                                return false, nil, fmt.Errorf("unable to "+
×
2180
                                        "split tor v3 address: %v",
×
2181
                                        addr.Address)
×
2182
                        }
×
2183

2184
                        port, err := strconv.Atoi(portStr)
×
2185
                        if err != nil {
×
2186
                                return false, nil, err
×
2187
                        }
×
2188

2189
                        addresses = append(addresses, &tor.OnionAddr{
×
2190
                                OnionService: service,
×
2191
                                Port:         port,
×
2192
                        })
×
2193

2194
                case addressTypeOpaque:
×
2195
                        opaque, err := hex.DecodeString(address)
×
2196
                        if err != nil {
×
2197
                                return false, nil, fmt.Errorf("unable to "+
×
2198
                                        "decode opaque address: %v", addr)
×
2199
                        }
×
2200

2201
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
2202
                                Payload: opaque,
×
2203
                        })
×
2204

2205
                default:
×
2206
                        return false, nil, fmt.Errorf("unknown address "+
×
2207
                                "type: %v", addr.Type)
×
2208
                }
2209
        }
2210

2211
        return true, addresses, nil
×
2212
}
2213

2214
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
2215
// database. This includes updating any existing types, inserting any new types,
2216
// and deleting any types that are no longer present.
2217
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
2218
        nodeID int64, extraFields map[uint64][]byte) error {
×
2219

×
2220
        // Get any existing extra signed fields for the node.
×
2221
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
2222
        if err != nil {
×
2223
                return err
×
2224
        }
×
2225

2226
        // Make a lookup map of the existing field types so that we can use it
2227
        // to keep track of any fields we should delete.
2228
        m := make(map[uint64]bool)
×
2229
        for _, field := range existingFields {
×
2230
                m[uint64(field.Type)] = true
×
2231
        }
×
2232

2233
        // For all the new fields, we'll upsert them and remove them from the
2234
        // map of existing fields.
2235
        for tlvType, value := range extraFields {
×
2236
                err = db.UpsertNodeExtraType(
×
2237
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
2238
                                NodeID: nodeID,
×
2239
                                Type:   int64(tlvType),
×
2240
                                Value:  value,
×
2241
                        },
×
2242
                )
×
2243
                if err != nil {
×
2244
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
2245
                                "signed field(%v): %w", nodeID, tlvType, err)
×
2246
                }
×
2247

2248
                // Remove the field from the map of existing fields if it was
2249
                // present.
2250
                delete(m, tlvType)
×
2251
        }
2252

2253
        // For all the fields that are left in the map of existing fields, we'll
2254
        // delete them as they are no longer present in the new set of fields.
2255
        for tlvType := range m {
×
2256
                err = db.DeleteExtraNodeType(
×
2257
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
2258
                                NodeID: nodeID,
×
2259
                                Type:   int64(tlvType),
×
2260
                        },
×
2261
                )
×
2262
                if err != nil {
×
2263
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
2264
                                "signed field(%v): %w", nodeID, tlvType, err)
×
2265
                }
×
2266
        }
2267

2268
        return nil
×
2269
}
2270

2271
// srcNodeInfo holds the information about the source node of the graph.
// getSourceNode caches one instance per protocol version in
// SQLStore.srcNodes.
type srcNodeInfo struct {
	// id is the DB level ID of the source node entry in the "nodes" table.
	id int64

	// pub is the public key of the source node.
	pub route.Vertex
}
2279

2280
// getSourceNode returns the DB node ID and pub key of the source node for the
2281
// specified protocol version.
2282
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
2283
        version ProtocolVersion) (int64, route.Vertex, error) {
×
2284

×
2285
        s.srcNodeMu.Lock()
×
2286
        defer s.srcNodeMu.Unlock()
×
2287

×
2288
        // If we already have the source node ID and pub key cached, then
×
2289
        // return them.
×
2290
        if info, ok := s.srcNodes[version]; ok {
×
2291
                return info.id, info.pub, nil
×
2292
        }
×
2293

2294
        var pubKey route.Vertex
×
2295

×
2296
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
2297
        if err != nil {
×
2298
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
2299
                        err)
×
2300
        }
×
2301

2302
        if len(nodes) == 0 {
×
2303
                return 0, pubKey, ErrSourceNodeNotSet
×
2304
        } else if len(nodes) > 1 {
×
2305
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
2306
                        "protocol %s found", version)
×
2307
        }
×
2308

2309
        copy(pubKey[:], nodes[0].PubKey)
×
2310

×
2311
        s.srcNodes[version] = &srcNodeInfo{
×
2312
                id:  nodes[0].NodeID,
×
2313
                pub: pubKey,
×
2314
        }
×
2315

×
2316
        return nodes[0].NodeID, pubKey, nil
×
2317
}
2318

2319
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
2320
// This then produces a map from TLV type to value. If the input is not a
2321
// valid TLV stream, then an error is returned.
2322
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
2323
        r := bytes.NewReader(data)
×
2324

×
2325
        tlvStream, err := tlv.NewStream()
×
2326
        if err != nil {
×
2327
                return nil, err
×
2328
        }
×
2329

2330
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
2331
        // pass it into the P2P decoding variant.
2332
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
2333
        if err != nil {
×
2334
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
2335
        }
×
2336
        if len(parsedTypes) == 0 {
×
2337
                return nil, nil
×
2338
        }
×
2339

2340
        records := make(map[uint64][]byte)
×
2341
        for k, v := range parsedTypes {
×
2342
                records[uint64(k)] = v
×
2343
        }
×
2344

2345
        return records, nil
×
2346
}
2347

2348
// insertChannel inserts a new channel record into the database.
2349
func insertChannel(ctx context.Context, db SQLQueries,
2350
        edge *models.ChannelEdgeInfo) error {
×
2351

×
2352
        var chanIDB [8]byte
×
2353
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
2354

×
2355
        // Make sure that the channel doesn't already exist. We do this
×
2356
        // explicitly instead of relying on catching a unique constraint error
×
2357
        // because relying on SQL to throw that error would abort the entire
×
2358
        // batch of transactions.
×
2359
        _, err := db.GetChannelBySCID(
×
2360
                ctx, sqlc.GetChannelBySCIDParams{
×
2361
                        Scid:    chanIDB[:],
×
2362
                        Version: int16(ProtocolV1),
×
2363
                },
×
2364
        )
×
2365
        if err == nil {
×
2366
                return ErrEdgeAlreadyExist
×
2367
        } else if !errors.Is(err, sql.ErrNoRows) {
×
2368
                return fmt.Errorf("unable to fetch channel: %w", err)
×
2369
        }
×
2370

2371
        // Make sure that at least a "shell" entry for each node is present in
2372
        // the nodes table.
2373
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
2374
        if err != nil {
×
2375
                return fmt.Errorf("unable to create shell node: %w", err)
×
2376
        }
×
2377

2378
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
2379
        if err != nil {
×
2380
                return fmt.Errorf("unable to create shell node: %w", err)
×
2381
        }
×
2382

2383
        var capacity sql.NullInt64
×
2384
        if edge.Capacity != 0 {
×
2385
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
2386
        }
×
2387

2388
        createParams := sqlc.CreateChannelParams{
×
2389
                Version:     int16(ProtocolV1),
×
2390
                Scid:        chanIDB[:],
×
2391
                NodeID1:     node1DBID,
×
2392
                NodeID2:     node2DBID,
×
2393
                Outpoint:    edge.ChannelPoint.String(),
×
2394
                Capacity:    capacity,
×
2395
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
2396
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
2397
        }
×
2398

×
2399
        if edge.AuthProof != nil {
×
2400
                proof := edge.AuthProof
×
2401

×
2402
                createParams.Node1Signature = proof.NodeSig1Bytes
×
2403
                createParams.Node2Signature = proof.NodeSig2Bytes
×
2404
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
2405
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
2406
        }
×
2407

2408
        // Insert the new channel record.
2409
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
2410
        if err != nil {
×
2411
                return err
×
2412
        }
×
2413

2414
        // Insert any channel features.
2415
        if len(edge.Features) != 0 {
×
2416
                chanFeatures := lnwire.NewRawFeatureVector()
×
2417
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
2418
                if err != nil {
×
2419
                        return err
×
2420
                }
×
2421

2422
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
2423
                for feature := range fv.Features() {
×
2424
                        err = db.InsertChannelFeature(
×
2425
                                ctx, sqlc.InsertChannelFeatureParams{
×
2426
                                        ChannelID:  dbChanID,
×
2427
                                        FeatureBit: int32(feature),
×
2428
                                },
×
2429
                        )
×
2430
                        if err != nil {
×
2431
                                return fmt.Errorf("unable to insert "+
×
2432
                                        "channel(%d) feature(%v): %w", dbChanID,
×
2433
                                        feature, err)
×
2434
                        }
×
2435
                }
2436
        }
2437

2438
        // Finally, insert any extra TLV fields in the channel announcement.
2439
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
2440
        if err != nil {
×
2441
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
2442
                        err)
×
2443
        }
×
2444

2445
        for tlvType, value := range extra {
×
2446
                err := db.CreateChannelExtraType(
×
2447
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
2448
                                ChannelID: dbChanID,
×
2449
                                Type:      int64(tlvType),
×
2450
                                Value:     value,
×
2451
                        },
×
2452
                )
×
2453
                if err != nil {
×
2454
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
2455
                                "signed field(%v): %w", edge.ChannelID,
×
2456
                                tlvType, err)
×
2457
                }
×
2458
        }
2459

2460
        return nil
×
2461
}
2462

2463
// maybeCreateShellNode checks if a shell node entry exists for the
2464
// given public key. If it does not exist, then a new shell node entry is
2465
// created. The ID of the node is returned. A shell node only has a protocol
2466
// version and public key persisted.
2467
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
2468
        pubKey route.Vertex) (int64, error) {
×
2469

×
2470
        dbNode, err := db.GetNodeByPubKey(
×
2471
                ctx, sqlc.GetNodeByPubKeyParams{
×
2472
                        PubKey:  pubKey[:],
×
2473
                        Version: int16(ProtocolV1),
×
2474
                },
×
2475
        )
×
2476
        // The node exists. Return the ID.
×
2477
        if err == nil {
×
2478
                return dbNode.ID, nil
×
2479
        } else if !errors.Is(err, sql.ErrNoRows) {
×
2480
                return 0, err
×
2481
        }
×
2482

2483
        // Otherwise, the node does not exist, so we create a shell entry for
2484
        // it.
2485
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
2486
                Version: int16(ProtocolV1),
×
2487
                PubKey:  pubKey[:],
×
2488
        })
×
2489
        if err != nil {
×
2490
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
2491
        }
×
2492

2493
        return id, nil
×
2494
}
2495

2496
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
2497
// the database. This includes deleting any existing types and then inserting
2498
// the new types.
2499
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
2500
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
2501

×
2502
        // Delete all existing extra signed fields for the channel policy.
×
2503
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
2504
        if err != nil {
×
2505
                return fmt.Errorf("unable to delete "+
×
2506
                        "existing policy extra signed fields for policy %d: %w",
×
2507
                        chanPolicyID, err)
×
2508
        }
×
2509

2510
        // Insert all new extra signed fields for the channel policy.
2511
        for tlvType, value := range extraFields {
×
2512
                err = db.InsertChanPolicyExtraType(
×
2513
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
2514
                                ChannelPolicyID: chanPolicyID,
×
2515
                                Type:            int64(tlvType),
×
2516
                                Value:           value,
×
2517
                        },
×
2518
                )
×
2519
                if err != nil {
×
2520
                        return fmt.Errorf("unable to insert "+
×
2521
                                "channel_policy(%d) extra signed field(%v): %w",
×
2522
                                chanPolicyID, tlvType, err)
×
2523
                }
×
2524
        }
2525

2526
        return nil
×
2527
}
2528

2529
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
2530
// provided dbChanRow and also fetches any other required information
2531
// to construct the edge info.
2532
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
2533
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
2534
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
2535

×
2536
        fv, extras, err := getChanFeaturesAndExtras(
×
2537
                ctx, db, dbChanID,
×
2538
        )
×
2539
        if err != nil {
×
2540
                return nil, err
×
2541
        }
×
2542

2543
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
2544
        if err != nil {
×
2545
                return nil, err
×
2546
        }
×
2547

2548
        var featureBuf bytes.Buffer
×
2549
        if err := fv.Encode(&featureBuf); err != nil {
×
2550
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
2551
        }
×
2552

2553
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2554
        if err != nil {
×
2555
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2556
                        "fields: %w", err)
×
2557
        }
×
2558
        if recs == nil {
×
2559
                recs = make([]byte, 0)
×
2560
        }
×
2561

2562
        var btcKey1, btcKey2 route.Vertex
×
2563
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
2564
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
2565

×
2566
        channel := &models.ChannelEdgeInfo{
×
2567
                ChainHash:        chain,
×
2568
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
2569
                NodeKey1Bytes:    node1,
×
2570
                NodeKey2Bytes:    node2,
×
2571
                BitcoinKey1Bytes: btcKey1,
×
2572
                BitcoinKey2Bytes: btcKey2,
×
2573
                ChannelPoint:     *op,
×
2574
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
2575
                Features:         featureBuf.Bytes(),
×
2576
                ExtraOpaqueData:  recs,
×
2577
        }
×
2578

×
2579
        // We always set all the signatures at the same time, so we can
×
2580
        // safely check if one signature is present to determine if we have the
×
2581
        // rest of the signatures for the auth proof.
×
2582
        if len(dbChan.Bitcoin1Signature) > 0 {
×
2583
                channel.AuthProof = &models.ChannelAuthProof{
×
2584
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
2585
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
2586
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
2587
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
2588
                }
×
2589
        }
×
2590

2591
        return channel, nil
×
2592
}
2593

2594
// buildNodeVertices is a helper that converts raw node public keys
2595
// into route.Vertex instances.
2596
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
2597
        route.Vertex, error) {
×
2598

×
2599
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
2600
        if err != nil {
×
2601
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2602
                        "create vertex from node1 pubkey: %w", err)
×
2603
        }
×
2604

2605
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
2606
        if err != nil {
×
2607
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2608
                        "create vertex from node2 pubkey: %w", err)
×
2609
        }
×
2610

2611
        return node1Vertex, node2Vertex, nil
×
2612
}
2613

2614
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
2615
// for a channel with the given ID.
2616
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
2617
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
2618

×
2619
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
2620
        if err != nil {
×
2621
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
2622
                        "features and extras: %w", err)
×
2623
        }
×
2624

2625
        var (
×
2626
                fv     = lnwire.EmptyFeatureVector()
×
2627
                extras = make(map[uint64][]byte)
×
2628
        )
×
2629
        for _, row := range rows {
×
2630
                if row.IsFeature {
×
2631
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
2632

×
2633
                        continue
×
2634
                }
2635

2636
                tlvType, ok := row.ExtraKey.(int64)
×
2637
                if !ok {
×
2638
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2639
                                "TLV type: %T", row.ExtraKey)
×
2640
                }
×
2641

2642
                valueBytes, ok := row.Value.([]byte)
×
2643
                if !ok {
×
2644
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2645
                                "Value: %T", row.Value)
×
2646
                }
×
2647

2648
                extras[uint64(tlvType)] = valueBytes
×
2649
        }
2650

2651
        return fv, extras, nil
×
2652
}
2653

2654
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
2655
// all the extra info required to build the complete models.ChannelEdgePolicy
2656
// types. It returns two policies, which may be nil if the provided
2657
// sqlc.ChannelPolicy records are nil.
2658
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
2659
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
2660
        node2 route.Vertex) (*models.ChannelEdgePolicy,
2661
        *models.ChannelEdgePolicy, error) {
×
2662

×
2663
        if dbPol1 == nil && dbPol2 == nil {
×
2664
                return nil, nil, nil
×
2665
        }
×
2666

2667
        var (
×
2668
                policy1ID int64
×
2669
                policy2ID int64
×
2670
        )
×
2671
        if dbPol1 != nil {
×
2672
                policy1ID = dbPol1.ID
×
2673
        }
×
2674
        if dbPol2 != nil {
×
2675
                policy2ID = dbPol2.ID
×
2676
        }
×
2677
        rows, err := db.GetChannelPolicyExtraTypes(
×
2678
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
2679
                        ID:   policy1ID,
×
2680
                        ID_2: policy2ID,
×
2681
                },
×
2682
        )
×
2683
        if err != nil {
×
2684
                return nil, nil, err
×
2685
        }
×
2686

2687
        var (
×
2688
                dbPol1Extras = make(map[uint64][]byte)
×
2689
                dbPol2Extras = make(map[uint64][]byte)
×
2690
        )
×
2691
        for _, row := range rows {
×
2692
                switch row.PolicyID {
×
2693
                case policy1ID:
×
2694
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
2695
                case policy2ID:
×
2696
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
2697
                default:
×
2698
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
2699
                                "in row: %v", row.PolicyID, row)
×
2700
                }
2701
        }
2702

2703
        var pol1, pol2 *models.ChannelEdgePolicy
×
2704
        if dbPol1 != nil {
×
2705
                pol1, err = buildChanPolicy(
×
2706
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
2707
                )
×
2708
                if err != nil {
×
2709
                        return nil, nil, err
×
2710
                }
×
2711
        }
2712
        if dbPol2 != nil {
×
2713
                pol2, err = buildChanPolicy(
×
2714
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
2715
                )
×
2716
                if err != nil {
×
2717
                        return nil, nil, err
×
2718
                }
×
2719
        }
2720

2721
        return pol1, pol2, nil
×
2722
}
2723

2724
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
2725
// provided sqlc.ChannelPolicy and other required information.
2726
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
2727
        extras map[uint64][]byte, toNode route.Vertex,
2728
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
2729

×
2730
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2731
        if err != nil {
×
2732
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2733
                        "fields: %w", err)
×
2734
        }
×
2735

2736
        var msgFlags lnwire.ChanUpdateMsgFlags
×
2737
        if dbPolicy.MaxHtlcMsat.Valid {
×
2738
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
2739
        }
×
2740

2741
        var chanFlags lnwire.ChanUpdateChanFlags
×
2742
        if !isNode1 {
×
2743
                chanFlags |= lnwire.ChanUpdateDirection
×
2744
        }
×
2745
        if dbPolicy.Disabled.Bool {
×
2746
                chanFlags |= lnwire.ChanUpdateDisabled
×
2747
        }
×
2748

2749
        var inboundFee fn.Option[lnwire.Fee]
×
2750
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
2751
                dbPolicy.InboundBaseFeeMsat.Valid {
×
2752

×
2753
                inboundFee = fn.Some(lnwire.Fee{
×
2754
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
2755
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
2756
                })
×
2757
        }
×
2758

2759
        return &models.ChannelEdgePolicy{
×
2760
                SigBytes:  dbPolicy.Signature,
×
2761
                ChannelID: channelID,
×
2762
                LastUpdate: time.Unix(
×
2763
                        dbPolicy.LastUpdate.Int64, 0,
×
2764
                ),
×
2765
                MessageFlags:  msgFlags,
×
2766
                ChannelFlags:  chanFlags,
×
2767
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
2768
                MinHTLC: lnwire.MilliSatoshi(
×
2769
                        dbPolicy.MinHtlcMsat,
×
2770
                ),
×
2771
                MaxHTLC: lnwire.MilliSatoshi(
×
2772
                        dbPolicy.MaxHtlcMsat.Int64,
×
2773
                ),
×
2774
                FeeBaseMSat: lnwire.MilliSatoshi(
×
2775
                        dbPolicy.BaseFeeMsat,
×
2776
                ),
×
2777
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
2778
                ToNode:                    toNode,
×
2779
                InboundFee:                inboundFee,
×
2780
                ExtraOpaqueData:           recs,
×
2781
        }, nil
×
2782
}
2783

2784
// getAndBuildNodes builds the models.LightningNode instances for the
2785
// given row which is expected to be a sqlc type that contains node information.
2786
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
2787
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
NEW
2788
        error) {
×
NEW
2789

×
NEW
2790
        node1, err := buildNode(ctx, db, &dbNode1)
×
NEW
2791
        if err != nil {
×
NEW
2792
                return nil, nil, err
×
NEW
2793
        }
×
2794

NEW
2795
        node2, err := buildNode(ctx, db, &dbNode2)
×
NEW
2796
        if err != nil {
×
NEW
2797
                return nil, nil, err
×
NEW
2798
        }
×
2799

NEW
2800
        return node1, node2, nil
×
2801
}
2802

2803
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the given
2804
// row which is expected to be a sqlc type that contains channel policy
2805
// information. It returns two policies, which may be nil if the policy
2806
// information is not present in the row.
2807
//
2808
//nolint:ll
2809
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
2810
        error) {
×
2811

×
2812
        var policy1, policy2 *sqlc.ChannelPolicy
×
2813
        switch r := row.(type) {
×
NEW
2814
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
2815
                if r.Policy1ID.Valid {
×
NEW
2816
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
2817
                                ID:                      r.Policy1ID.Int64,
×
NEW
2818
                                Version:                 r.Policy1Version.Int16,
×
NEW
2819
                                ChannelID:               r.Channel.ID,
×
NEW
2820
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
2821
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
2822
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
2823
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
2824
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
2825
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
2826
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
2827
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
2828
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
2829
                                Disabled:                r.Policy1Disabled,
×
NEW
2830
                                Signature:               r.Policy1Signature,
×
NEW
2831
                        }
×
NEW
2832
                }
×
NEW
2833
                if r.Policy2ID.Valid {
×
NEW
2834
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
2835
                                ID:                      r.Policy2ID.Int64,
×
NEW
2836
                                Version:                 r.Policy2Version.Int16,
×
NEW
2837
                                ChannelID:               r.Channel.ID,
×
NEW
2838
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
2839
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
2840
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
2841
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
2842
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
2843
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
2844
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
2845
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
2846
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
2847
                                Disabled:                r.Policy2Disabled,
×
NEW
2848
                                Signature:               r.Policy2Signature,
×
NEW
2849
                        }
×
NEW
2850
                }
×
NEW
2851
                return policy1, policy2, nil
×
2852

2853
        case sqlc.ListChannelsByNodeIDRow:
×
2854
                if r.Policy1ID.Valid {
×
2855
                        policy1 = &sqlc.ChannelPolicy{
×
2856
                                ID:                      r.Policy1ID.Int64,
×
2857
                                Version:                 r.Policy1Version.Int16,
×
2858
                                ChannelID:               r.Channel.ID,
×
2859
                                NodeID:                  r.Policy1NodeID.Int64,
×
2860
                                Timelock:                r.Policy1Timelock.Int32,
×
2861
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
2862
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
2863
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
2864
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
2865
                                LastUpdate:              r.Policy1LastUpdate,
×
2866
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
2867
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
2868
                                Disabled:                r.Policy1Disabled,
×
2869
                                Signature:               r.Policy1Signature,
×
2870
                        }
×
2871
                }
×
2872
                if r.Policy2ID.Valid {
×
2873
                        policy2 = &sqlc.ChannelPolicy{
×
2874
                                ID:                      r.Policy2ID.Int64,
×
2875
                                Version:                 r.Policy2Version.Int16,
×
2876
                                ChannelID:               r.Channel.ID,
×
2877
                                NodeID:                  r.Policy2NodeID.Int64,
×
2878
                                Timelock:                r.Policy2Timelock.Int32,
×
2879
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
2880
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
2881
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
2882
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
2883
                                LastUpdate:              r.Policy2LastUpdate,
×
2884
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
2885
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
2886
                                Disabled:                r.Policy2Disabled,
×
2887
                                Signature:               r.Policy2Signature,
×
2888
                        }
×
2889
                }
×
2890

NEW
2891
                return policy1, policy2, nil
×
2892

NEW
2893
        case sqlc.ListChannelsPaginatedRow:
×
NEW
2894
                if r.Policy1ID.Valid {
×
NEW
2895
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
2896
                                ID:                      r.Policy1ID.Int64,
×
NEW
2897
                                Version:                 r.Policy1Version.Int16,
×
NEW
2898
                                ChannelID:               r.Channel.ID,
×
NEW
2899
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
2900
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
2901
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
2902
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
2903
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
2904
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
2905
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
2906
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
2907
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
2908
                                Disabled:                r.Policy1Disabled,
×
NEW
2909
                                Signature:               r.Policy1Signature,
×
NEW
2910
                        }
×
NEW
2911
                }
×
NEW
2912
                if r.Policy2ID.Valid {
×
NEW
2913
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
2914
                                ID:                      r.Policy2ID.Int64,
×
NEW
2915
                                Version:                 r.Policy2Version.Int16,
×
NEW
2916
                                ChannelID:               r.Channel.ID,
×
NEW
2917
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
2918
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
2919
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
2920
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
2921
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
2922
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
2923
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
2924
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
2925
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
2926
                                Disabled:                r.Policy2Disabled,
×
NEW
2927
                                Signature:               r.Policy2Signature,
×
NEW
2928
                        }
×
NEW
2929
                }
×
2930

2931
                return policy1, policy2, nil
×
2932
        default:
×
2933
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
2934
                        "extractChannelPolicies: %T", r)
×
2935
        }
2936
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc