• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15592669300

11 Jun 2025 06:25PM UTC coverage: 58.274% (-10.2%) from 68.522%
15592669300

Pull #9887

github

web-flow
Merge 1950bd519 into 07f65b511
Pull Request #9887: graph/db+sqldb: channel policy SQL schemas, queries and upsert CRUD

0 of 186 new or added lines in 1 file covered. (0.0%)

28429 existing lines in 454 files now uncovered.

97734 of 167716 relevant lines covered (58.27%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "math"
11
        "net"
12
        "strconv"
13
        "sync"
14
        "time"
15

16
        "github.com/btcsuite/btcd/btcec/v2"
17
        "github.com/btcsuite/btcd/chaincfg/chainhash"
18
        "github.com/lightningnetwork/lnd/batch"
19
        "github.com/lightningnetwork/lnd/graph/db/models"
20
        "github.com/lightningnetwork/lnd/lnwire"
21
        "github.com/lightningnetwork/lnd/routing/route"
22
        "github.com/lightningnetwork/lnd/sqldb"
23
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
24
        "github.com/lightningnetwork/lnd/tlv"
25
        "github.com/lightningnetwork/lnd/tor"
26
)
27

28
// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version, e.g. "V1".
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", uint8(v))
}
41

42
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. Keeping it a subset means
// tests and callers only need to satisfy the queries the graph store actually
// uses.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)

	// Node extra-TLV (extra signed fields) queries.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node address queries.
	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature-bit queries.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)

	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)

	GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
	DeleteChannelPolicyExtraType(ctx context.Context, arg sqlc.DeleteChannelPolicyExtraTypeParams) error
	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
}
94

95
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations. It combines the raw query methods with the ability to
// execute a function atomically within a single transaction via
// sqldb.BatchedTx.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
101

102
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented,  things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests only.
type SQLStore struct {
	// cfg holds the static store configuration, such as the chain hash.
	cfg *SQLStoreConfig

	// db is the handle used to run queries and batched transactions
	// against the SQL backend.
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler and nodeScheduler batch together channel and node
	// writes respectively so that many concurrent updates can be committed
	// in a single database transaction.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// Temporary fall-back to the KVStore so that we can implement the
	// interface incrementally.
	*KVStore
}
127

128
// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash
}
138

139
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
140
// storage backend.
141
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
142
        options ...StoreOptionModifier) (*SQLStore, error) {
×
143

×
144
        opts := DefaultOptions()
×
145
        for _, o := range options {
×
146
                o(opts)
×
147
        }
×
148

149
        if opts.NoMigration {
×
150
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
151
                        "supported for SQL stores")
×
152
        }
×
153

154
        s := &SQLStore{
×
NEW
155
                cfg:         cfg,
×
156
                db:          db,
×
157
                KVStore:     kvStore,
×
158
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
159
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
160
        }
×
161

×
162
        s.chanScheduler = batch.NewTimeScheduler(
×
163
                db, &s.cacheMu, opts.BatchCommitInterval,
×
164
        )
×
165
        s.nodeScheduler = batch.NewTimeScheduler(
×
166
                db, nil, opts.BatchCommitInterval,
×
167
        )
×
168

×
169
        return s, nil
×
170
}
171

172
// AddLightningNode adds a vertex/node to the graph database. If the node is not
173
// in the database from before, this will add a new, unconnected one to the
174
// graph. If it is present from before, this will update that node's
175
// information.
176
//
177
// NOTE: part of the V1Store interface.
178
func (s *SQLStore) AddLightningNode(node *models.LightningNode,
179
        opts ...batch.SchedulerOption) error {
×
180

×
181
        ctx := context.TODO()
×
182

×
183
        r := &batch.Request[SQLQueries]{
×
184
                Opts: batch.NewSchedulerOptions(opts...),
×
185
                Do: func(queries SQLQueries) error {
×
186
                        _, err := upsertNode(ctx, queries, node)
×
187
                        return err
×
188
                },
×
189
        }
190

191
        return s.nodeScheduler.Execute(ctx, r)
×
192
}
193

194
// FetchLightningNode attempts to look up a target node by its identity public
195
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
196
// returned.
197
//
198
// NOTE: part of the V1Store interface.
199
func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
200
        *models.LightningNode, error) {
×
201

×
202
        ctx := context.TODO()
×
203

×
204
        var node *models.LightningNode
×
205
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
206
                var err error
×
207
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
208

×
209
                return err
×
210
        }, sqldb.NoOpReset)
×
211
        if err != nil {
×
212
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
213
        }
×
214

215
        return node, nil
×
216
}
217

218
// HasLightningNode determines if the graph has a vertex identified by the
219
// target node identity public key. If the node exists in the database, a
220
// timestamp of when the data for the node was lasted updated is returned along
221
// with a true boolean. Otherwise, an empty time.Time is returned with a false
222
// boolean.
223
//
224
// NOTE: part of the V1Store interface.
225
func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
226
        error) {
×
227

×
228
        ctx := context.TODO()
×
229

×
230
        var (
×
231
                exists     bool
×
232
                lastUpdate time.Time
×
233
        )
×
234
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
235
                dbNode, err := db.GetNodeByPubKey(
×
236
                        ctx, sqlc.GetNodeByPubKeyParams{
×
237
                                Version: int16(ProtocolV1),
×
238
                                PubKey:  pubKey[:],
×
239
                        },
×
240
                )
×
241
                if errors.Is(err, sql.ErrNoRows) {
×
242
                        return nil
×
243
                } else if err != nil {
×
244
                        return fmt.Errorf("unable to fetch node: %w", err)
×
245
                }
×
246

247
                exists = true
×
248

×
249
                if dbNode.LastUpdate.Valid {
×
250
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
251
                }
×
252

253
                return nil
×
254
        }, sqldb.NoOpReset)
255
        if err != nil {
×
256
                return time.Time{}, false,
×
257
                        fmt.Errorf("unable to fetch node: %w", err)
×
258
        }
×
259

260
        return lastUpdate, exists, nil
×
261
}
262

263
// AddrsForNode returns all known addresses for the target node public key
264
// that the graph DB is aware of. The returned boolean indicates if the
265
// given node is unknown to the graph DB or not.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
269
        error) {
×
270

×
271
        ctx := context.TODO()
×
272

×
273
        var (
×
274
                addresses []net.Addr
×
275
                known     bool
×
276
        )
×
277
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
278
                var err error
×
279
                known, addresses, err = getNodeAddresses(
×
280
                        ctx, db, nodePub.SerializeCompressed(),
×
281
                )
×
282
                if err != nil {
×
283
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
284
                                err)
×
285
                }
×
286

287
                return nil
×
288
        }, sqldb.NoOpReset)
289
        if err != nil {
×
290
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
291
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
292
        }
×
293

294
        return known, addresses, nil
×
295
}
296

297
// DeleteLightningNode starts a new database transaction to remove a vertex/node
298
// from the database according to the node's public key.
299
//
300
// NOTE: part of the V1Store interface.
301
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
×
302
        ctx := context.TODO()
×
303

×
304
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
305
                res, err := db.DeleteNodeByPubKey(
×
306
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
307
                                Version: int16(ProtocolV1),
×
308
                                PubKey:  pubKey[:],
×
309
                        },
×
310
                )
×
311
                if err != nil {
×
312
                        return err
×
313
                }
×
314

315
                rows, err := res.RowsAffected()
×
316
                if err != nil {
×
317
                        return err
×
318
                }
×
319

320
                if rows == 0 {
×
321
                        return ErrGraphNodeNotFound
×
322
                } else if rows > 1 {
×
323
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
324
                }
×
325

326
                return err
×
327
        }, sqldb.NoOpReset)
328
        if err != nil {
×
329
                return fmt.Errorf("unable to delete node: %w", err)
×
330
        }
×
331

332
        return nil
×
333
}
334

335
// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
	*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	// Delegate to the shared helper, passing the store's query handle
	// directly (no explicit transaction is opened for this single read).
	return fetchNodeFeatures(ctx, s.db, nodePub)
}
346

347
// LookupAlias attempts to return the alias as advertised by the target node.
348
//
349
// NOTE: part of the V1Store interface.
350
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
×
351
        var (
×
352
                ctx   = context.TODO()
×
353
                alias string
×
354
        )
×
355
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
356
                dbNode, err := db.GetNodeByPubKey(
×
357
                        ctx, sqlc.GetNodeByPubKeyParams{
×
358
                                Version: int16(ProtocolV1),
×
359
                                PubKey:  pub.SerializeCompressed(),
×
360
                        },
×
361
                )
×
362
                if errors.Is(err, sql.ErrNoRows) {
×
363
                        return ErrNodeAliasNotFound
×
364
                } else if err != nil {
×
365
                        return fmt.Errorf("unable to fetch node: %w", err)
×
366
                }
×
367

368
                if !dbNode.Alias.Valid {
×
369
                        return ErrNodeAliasNotFound
×
370
                }
×
371

372
                alias = dbNode.Alias.String
×
373

×
374
                return nil
×
375
        }, sqldb.NoOpReset)
376
        if err != nil {
×
377
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
378
        }
×
379

380
        return alias, nil
×
381
}
382

383
// SourceNode returns the source node of the graph. The source node is treated
384
// as the center node within a star-graph. This method may be used to kick off
385
// a path finding algorithm in order to explore the reachability of another
386
// node based off the source node.
387
//
388
// NOTE: part of the V1Store interface.
389
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
×
390
        ctx := context.TODO()
×
391

×
392
        var node *models.LightningNode
×
393
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
394
                _, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
×
395
                if err != nil {
×
396
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
397
                                err)
×
398
                }
×
399

400
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
401

×
402
                return err
×
403
        }, sqldb.NoOpReset)
404
        if err != nil {
×
405
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
406
        }
×
407

408
        return node, nil
×
409
}
410

411
// SetSourceNode sets the source node within the graph database. The source
412
// node is to be used as the center of a star-graph within path finding
413
// algorithms.
414
//
415
// NOTE: part of the V1Store interface.
416
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
×
417
        ctx := context.TODO()
×
418

×
419
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
420
                id, err := upsertNode(ctx, db, node)
×
421
                if err != nil {
×
422
                        return fmt.Errorf("unable to upsert source node: %w",
×
423
                                err)
×
424
                }
×
425

426
                // Make sure that if a source node for this version is already
427
                // set, then the ID is the same as the one we are about to set.
428
                dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
×
429
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
430
                        return fmt.Errorf("unable to fetch source node: %w",
×
431
                                err)
×
432
                } else if err == nil {
×
433
                        if dbSourceNodeID != id {
×
434
                                return fmt.Errorf("v1 source node already "+
×
435
                                        "set to a different node: %d vs %d",
×
436
                                        dbSourceNodeID, id)
×
437
                        }
×
438

439
                        return nil
×
440
                }
441

442
                return db.AddSourceNode(ctx, id)
×
443
        }, sqldb.NoOpReset)
444
}
445

446
// NodeUpdatesInHorizon returns all the known lightning node which have an
447
// update timestamp within the passed range. This method can be used by two
448
// nodes to quickly determine if they have the same set of up to date node
449
// announcements.
450
//
451
// NOTE: This is part of the V1Store interface.
452
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
453
        endTime time.Time) ([]models.LightningNode, error) {
×
454

×
455
        ctx := context.TODO()
×
456

×
457
        var nodes []models.LightningNode
×
458
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
459
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
460
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
461
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
462
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
463
                        },
×
464
                )
×
465
                if err != nil {
×
466
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
467
                }
×
468

469
                for _, dbNode := range dbNodes {
×
470
                        node, err := buildNode(ctx, db, &dbNode)
×
471
                        if err != nil {
×
472
                                return fmt.Errorf("unable to build node: %w",
×
473
                                        err)
×
474
                        }
×
475

476
                        nodes = append(nodes, *node)
×
477
                }
478

479
                return nil
×
480
        }, sqldb.NoOpReset)
481
        if err != nil {
×
482
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
483
        }
×
484

485
        return nodes, nil
×
486
}
487

488
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
489
// undirected edge from the two target nodes are created. The information stored
490
// denotes the static attributes of the channel, such as the channelID, the keys
491
// involved in creation of the channel, and the set of features that the channel
492
// supports. The chanPoint and chanID are used to uniquely identify the edge
493
// globally within the database.
494
//
495
// NOTE: part of the V1Store interface.
496
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
497
        opts ...batch.SchedulerOption) error {
×
498

×
499
        ctx := context.TODO()
×
500

×
501
        var alreadyExists bool
×
502
        r := &batch.Request[SQLQueries]{
×
503
                Opts: batch.NewSchedulerOptions(opts...),
×
504
                Reset: func() {
×
505
                        alreadyExists = false
×
506
                },
×
507
                Do: func(tx SQLQueries) error {
×
508
                        err := insertChannel(ctx, tx, edge)
×
509

×
510
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
511
                        // succeed, but propagate the error via local state.
×
512
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
513
                                alreadyExists = true
×
514
                                return nil
×
515
                        }
×
516

517
                        return err
×
518
                },
519
                OnCommit: func(err error) error {
×
520
                        switch {
×
521
                        case err != nil:
×
522
                                return err
×
523
                        case alreadyExists:
×
524
                                return ErrEdgeAlreadyExist
×
525
                        default:
×
526
                                s.rejectCache.remove(edge.ChannelID)
×
527
                                s.chanCache.remove(edge.ChannelID)
×
528
                                return nil
×
529
                        }
530
                },
531
        }
532

533
        return s.chanScheduler.Execute(ctx, r)
×
534
}
535

536
// HighestChanID returns the "highest" known channel ID in the channel graph.
537
// This represents the "newest" channel from the PoV of the chain. This method
538
// can be used by peers to quickly determine if their graphs are in sync.
539
//
540
// NOTE: This is part of the V1Store interface.
541
func (s *SQLStore) HighestChanID() (uint64, error) {
×
542
        ctx := context.TODO()
×
543

×
544
        var highestChanID uint64
×
545
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
546
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
547
                if errors.Is(err, sql.ErrNoRows) {
×
548
                        return nil
×
549
                } else if err != nil {
×
550
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
551
                                err)
×
552
                }
×
553

554
                highestChanID = byteOrder.Uint64(chanID)
×
555

×
556
                return nil
×
557
        }, sqldb.NoOpReset)
558
        if err != nil {
×
559
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
560
        }
×
561

562
        return highestChanID, nil
×
563
}
564

565
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
566
// within the database for the referenced channel. The `flags` attribute within
567
// the ChannelEdgePolicy determines which of the directed edges are being
568
// updated. If the flag is 1, then the first node's information is being
569
// updated, otherwise it's the second node's information. The node ordering is
570
// determined by the lexicographical ordering of the identity public keys of the
571
// nodes on either side of the channel.
572
//
573
// NOTE: part of the V1Store interface.
574
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
NEW
575
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
NEW
576

×
NEW
577
        ctx := context.TODO()
×
NEW
578

×
NEW
579
        var (
×
NEW
580
                isUpdate1    bool
×
NEW
581
                edgeNotFound bool
×
NEW
582
                from, to     route.Vertex
×
NEW
583
        )
×
NEW
584

×
NEW
585
        r := &batch.Request[SQLQueries]{
×
NEW
586
                Opts: batch.NewSchedulerOptions(opts...),
×
NEW
587
                Reset: func() {
×
NEW
588
                        isUpdate1 = false
×
NEW
589
                        edgeNotFound = false
×
NEW
590
                },
×
NEW
591
                Do: func(tx SQLQueries) error {
×
NEW
592
                        var err error
×
NEW
593
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
NEW
594
                                ctx, tx, edge,
×
NEW
595
                        )
×
NEW
596
                        if err != nil {
×
NEW
597
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
NEW
598
                        }
×
599

600
                        // Silence ErrEdgeNotFound so that the batch can
601
                        // succeed, but propagate the error via local state.
NEW
602
                        if errors.Is(err, ErrEdgeNotFound) {
×
NEW
603
                                edgeNotFound = true
×
NEW
604
                                return nil
×
NEW
605
                        }
×
606

NEW
607
                        return err
×
608
                },
NEW
609
                OnCommit: func(err error) error {
×
NEW
610
                        switch {
×
NEW
611
                        case err != nil:
×
NEW
612
                                return err
×
NEW
613
                        case edgeNotFound:
×
NEW
614
                                return ErrEdgeNotFound
×
NEW
615
                        default:
×
NEW
616
                                s.updateEdgeCache(edge, isUpdate1)
×
NEW
617
                                return nil
×
618
                        }
619
                },
620
        }
621

NEW
622
        err := s.chanScheduler.Execute(ctx, r)
×
NEW
623

×
NEW
624
        return from, to, err
×
625
}
626

627
// updateEdgeCache updates our reject and channel caches with the new
628
// edge policy information.
629
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
NEW
630
        isUpdate1 bool) {
×
NEW
631

×
NEW
632
        // If an entry for this channel is found in reject cache, we'll modify
×
NEW
633
        // the entry with the updated timestamp for the direction that was just
×
NEW
634
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
NEW
635
        // during the next query for this edge.
×
NEW
636
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
NEW
637
                if isUpdate1 {
×
NEW
638
                        entry.upd1Time = e.LastUpdate.Unix()
×
NEW
639
                } else {
×
NEW
640
                        entry.upd2Time = e.LastUpdate.Unix()
×
NEW
641
                }
×
NEW
642
                s.rejectCache.insert(e.ChannelID, entry)
×
643
        }
644

645
        // If an entry for this channel is found in channel cache, we'll modify
646
        // the entry with the updated policy for the direction that was just
647
        // written. If the edge doesn't exist, we'll defer loading the info and
648
        // policies and lazily read from disk during the next query.
NEW
649
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
NEW
650
                if isUpdate1 {
×
NEW
651
                        channel.Policy1 = e
×
NEW
652
                } else {
×
NEW
653
                        channel.Policy2 = e
×
NEW
654
                }
×
NEW
655
                s.chanCache.insert(e.ChannelID, channel)
×
656
        }
657
}
658

659
// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
//
// It returns the pubkeys of the channel's node1 and node2, a boolean that is
// true when the policy belongs to node1 (the direction bit of the policy's
// channel flags is unset), and an error. ErrEdgeNotFound is returned when the
// referenced channel does not exist.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
	edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
	error) {

	var (
		node1Pub, node2Pub route.Vertex
		isNode1            bool
		chanIDB            [8]byte
	)
	// Serialize the channel ID with the package-level byte order so the
	// lookup below matches how SCIDs are stored in the channels table.
	byteOrder.PutUint64(chanIDB[:], edge.ChannelID)

	// Check that this edge policy refers to a channel that we already
	// know of. We do this explicitly so that we can return the appropriate
	// ErrEdgeNotFound error if the channel doesn't exist, rather than
	// abort the transaction which would abort the entire batch.
	dbChan, err := tx.GetChannelAndNodesBySCID(
		ctx, sqlc.GetChannelAndNodesBySCIDParams{
			Scid:    chanIDB[:],
			Version: int16(ProtocolV1),
		},
	)
	if errors.Is(err, sql.ErrNoRows) {
		return node1Pub, node2Pub, false, ErrEdgeNotFound
	} else if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"fetch channel(%v): %w", edge.ChannelID, err)
	}

	copy(node1Pub[:], dbChan.Node1PubKey)
	copy(node2Pub[:], dbChan.Node2PubKey)

	// Figure out which node this edge is from.
	isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
	nodeID := dbChan.NodeID1
	if !isNode1 {
		nodeID = dbChan.NodeID2
	}

	// The inbound fee columns are only populated when the policy actually
	// carries an inbound fee; otherwise they remain SQL NULL.
	var (
		inboundBase sql.NullInt64
		inboundRate sql.NullInt64
	)
	edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
		inboundRate = sqldb.SQLInt64(fee.FeeRate)
		inboundBase = sqldb.SQLInt64(fee.BaseFee)
	})

	// Upsert the policy row itself; the returned id is used below to
	// attach the policy's extra signed TLV fields.
	id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
		Version:     int16(ProtocolV1),
		ChannelID:   dbChan.ID,
		NodeID:      nodeID,
		Timelock:    int32(edge.TimeLockDelta),
		FeePpm:      int64(edge.FeeProportionalMillionths),
		BaseFeeMsat: int64(edge.FeeBaseMSat),
		MinHtlcMsat: int64(edge.MinHTLC),
		LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
		Disabled: sql.NullBool{
			Valid: true,
			Bool:  edge.IsDisabled(),
		},
		// MaxHtlcMsat is only valid when the message flags advertise
		// a max HTLC value.
		MaxHtlcMsat: sql.NullInt64{
			Valid: edge.MessageFlags.HasMaxHtlc(),
			Int64: int64(edge.MaxHTLC),
		},
		InboundBaseFeeMsat:      inboundBase,
		InboundFeeRateMilliMsat: inboundRate,
		Signature:               edge.SigBytes,
	})
	if err != nil {
		return node1Pub, node2Pub, isNode1,
			fmt.Errorf("unable to upsert edge policy: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"marshal extra opaque data: %w", err)
	}

	// Update the channel policy's extra signed fields.
	err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
			"policy extra TLVs: %w", err)
	}

	return node1Pub, node2Pub, isNode1, nil
}
751

752
// getNodeByPubKey attempts to look up a target node by its public key.
753
func getNodeByPubKey(ctx context.Context, db SQLQueries,
754
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
755

×
756
        dbNode, err := db.GetNodeByPubKey(
×
757
                ctx, sqlc.GetNodeByPubKeyParams{
×
758
                        Version: int16(ProtocolV1),
×
759
                        PubKey:  pubKey[:],
×
760
                },
×
761
        )
×
762
        if errors.Is(err, sql.ErrNoRows) {
×
763
                return 0, nil, ErrGraphNodeNotFound
×
764
        } else if err != nil {
×
765
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
766
        }
×
767

768
        node, err := buildNode(ctx, db, &dbNode)
×
769
        if err != nil {
×
770
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
771
        }
×
772

773
        return dbNode.ID, node, nil
×
774
}
775

776
// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
	*models.LightningNode, error) {

	if dbNode.Version != int16(ProtocolV1) {
		return nil, fmt.Errorf("unsupported node version: %d",
			dbNode.Version)
	}

	var pub [33]byte
	copy(pub[:], dbNode.PubKey)

	node := &models.LightningNode{
		PubKeyBytes: pub,
		Features:    lnwire.EmptyFeatureVector(),
		LastUpdate:  time.Unix(0, 0),
	}

	// An empty signature indicates we only have a minimal record for this
	// node (presumably a "shell" entry created as a channel counterparty,
	// before its node announcement arrived — see upsertNode, which only
	// persists the announcement fields when one is available). In that
	// case, return the node with just its public key populated.
	if len(dbNode.Signature) == 0 {
		return node, nil
	}

	node.HaveNodeAnnouncement = true
	node.AuthSigBytes = dbNode.Signature
	node.Alias = dbNode.Alias.String
	node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

	var err error
	node.Color, err = DecodeHexColor(dbNode.Color.String)
	if err != nil {
		return nil, fmt.Errorf("unable to decode color: %w", err)
	}

	// Fetch the node's features.
	node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"features: %w", dbNode.ID, err)
	}

	// Fetch the node's addresses.
	_, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"addresses: %w", dbNode.ID, err)
	}

	// Fetch the node's extra signed fields.
	extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"extra signed fields: %w", dbNode.ID, err)
	}

	// Re-serialise the TLV map back into the flat wire-encoding form that
	// the LightningNode model stores.
	recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
	if err != nil {
		return nil, fmt.Errorf("unable to serialize extra signed "+
			"fields: %w", err)
	}

	// Only set the field when non-empty so a node with no extra data keeps
	// a nil ExtraOpaqueData rather than an empty slice.
	if len(recs) != 0 {
		node.ExtraOpaqueData = recs
	}

	return node, nil
}
844

845
// getNodeFeatures fetches the feature bits and constructs the feature vector
846
// for a node with the given DB ID.
847
func getNodeFeatures(ctx context.Context, db SQLQueries,
848
        nodeID int64) (*lnwire.FeatureVector, error) {
×
849

×
850
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
851
        if err != nil {
×
852
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
853
                        nodeID, err)
×
854
        }
×
855

856
        features := lnwire.EmptyFeatureVector()
×
857
        for _, feature := range rows {
×
858
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
859
        }
×
860

861
        return features, nil
×
862
}
863

864
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
865
// given DB ID.
866
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
867
        nodeID int64) (map[uint64][]byte, error) {
×
868

×
869
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
870
        if err != nil {
×
871
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
872
                        "signed fields: %w", nodeID, err)
×
873
        }
×
874

875
        extraFields := make(map[uint64][]byte)
×
876
        for _, field := range fields {
×
877
                extraFields[uint64(field.Type)] = field.Value
×
878
        }
×
879

880
        return extraFields, nil
×
881
}
882

883
// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
	node *models.LightningNode) (int64, error) {

	params := sqlc.UpsertNodeParams{
		Version: int16(ProtocolV1),
		PubKey:  node.PubKeyBytes[:],
	}

	// The announcement-derived columns (timestamp, colour, alias and
	// signature) are only populated once we actually have the node's
	// announcement. Without it, only the version and public key are
	// persisted, yielding a "shell" record.
	if node.HaveNodeAnnouncement {
		params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
		params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
		params.Alias = sqldb.SQLStr(node.Alias)
		params.Signature = node.AuthSigBytes
	}

	nodeID, err := db.UpsertNode(ctx, params)
	if err != nil {
		return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
			err)
	}

	// We can exit here if we don't have the announcement yet.
	if !node.HaveNodeAnnouncement {
		return nodeID, nil
	}

	// Update the node's features.
	err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
	if err != nil {
		return 0, fmt.Errorf("inserting node features: %w", err)
	}

	// Update the node's addresses.
	err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
	if err != nil {
		return 0, fmt.Errorf("inserting node addresses: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
	if err != nil {
		return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	// Update the node's extra signed fields.
	err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
	if err != nil {
		return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
	}

	return nodeID, nil
}
941

942
// upsertNodeFeatures updates the node's features node_features table. This
// includes deleting any feature bits no longer present and inserting any new
// feature bits. If the feature bit does not yet exist in the features table,
// then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
	features *lnwire.FeatureVector) error {

	// Get any existing features for the node.
	existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	// Copy the nodes latest set of feature bits. A nil feature vector is
	// treated like an empty one, so every previously stored bit ends up
	// deleted by the reconciliation loop below.
	newFeatures := make(map[int32]struct{})
	if features != nil {
		for feature := range features.Features() {
			newFeatures[int32(feature)] = struct{}{}
		}
	}

	// For any current feature that already exists in the DB, remove it from
	// the in-memory map. For any existing feature that does not exist in
	// the in-memory map, delete it from the database.
	for _, feature := range existingFeatures {
		// The feature is still present, so there are no updates to be
		// made.
		if _, ok := newFeatures[feature.FeatureBit]; ok {
			delete(newFeatures, feature.FeatureBit)
			continue
		}

		// The feature is no longer present, so we remove it from the
		// database.
		err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature.FeatureBit,
		})
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) "+
				"feature(%v): %w", nodeID, feature.FeatureBit,
				err)
		}
	}

	// Any remaining entries in newFeatures are new features that need to be
	// added to the database for the first time.
	for feature := range newFeatures {
		err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature,
		})
		if err != nil {
			return fmt.Errorf("unable to insert node(%d) "+
				"feature(%v): %w", nodeID, feature, err)
		}
	}

	return nil
}
1002

1003
// fetchNodeFeatures fetches the features for a node with the given public key.
1004
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
1005
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
1006

×
1007
        rows, err := queries.GetNodeFeaturesByPubKey(
×
1008
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
1009
                        PubKey:  nodePub[:],
×
1010
                        Version: int16(ProtocolV1),
×
1011
                },
×
1012
        )
×
1013
        if err != nil {
×
1014
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
1015
                        nodePub, err)
×
1016
        }
×
1017

1018
        features := lnwire.EmptyFeatureVector()
×
1019
        for _, bit := range rows {
×
1020
                features.Set(lnwire.FeatureBit(bit))
×
1021
        }
×
1022

1023
        return features, nil
×
1024
}
1025

1026
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialize.
type dbAddressType uint8

const (
	// addressTypeIPv4 is the type for a TCP over IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is the type for a TCP over IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is the type for a v2 Tor onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is the type for a v3 Tor onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is the type for an address of an unrecognised
	// format. math.MaxInt8 is used as the value, presumably to leave room
	// for future concrete address types before it — confirm before adding
	// new entries.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
1038

1039
// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
	addresses []net.Addr) error {

	// Delete any existing addresses for the node. This is required since
	// even if the new set of addresses is the same, the ordering may have
	// changed for a given address type.
	err := db.DeleteNodeAddresses(ctx, nodeID)
	if err != nil {
		return fmt.Errorf("unable to delete node(%d) addresses: %w",
			nodeID, err)
	}

	// Copy the nodes latest set of addresses, bucketed by DB address type.
	// Within a bucket, insertion order (tracked via the position column
	// below) preserves the announcement's address ordering.
	newAddresses := map[dbAddressType][]string{
		addressTypeIPv4:   {},
		addressTypeIPv6:   {},
		addressTypeTorV2:  {},
		addressTypeTorV3:  {},
		addressTypeOpaque: {},
	}
	addAddr := func(t dbAddressType, addr net.Addr) {
		newAddresses[t] = append(newAddresses[t], addr.String())
	}

	for _, address := range addresses {
		switch addr := address.(type) {
		case *net.TCPAddr:
			if ip4 := addr.IP.To4(); ip4 != nil {
				addAddr(addressTypeIPv4, addr)
			} else if ip6 := addr.IP.To16(); ip6 != nil {
				addAddr(addressTypeIPv6, addr)
			} else {
				return fmt.Errorf("unhandled IP address: %v",
					addr)
			}

		case *tor.OnionAddr:
			// v2 and v3 onion services are distinguished purely by
			// the length of the service string.
			switch len(addr.OnionService) {
			case tor.V2Len:
				addAddr(addressTypeTorV2, addr)
			case tor.V3Len:
				addAddr(addressTypeTorV3, addr)
			default:
				return fmt.Errorf("invalid length for a tor " +
					"address")
			}

		case *lnwire.OpaqueAddrs:
			// NOTE(review): this relies on OpaqueAddrs.String()
			// producing a representation that getNodeAddresses can
			// hex-decode on read — confirm against lnwire.
			addAddr(addressTypeOpaque, addr)

		default:
			return fmt.Errorf("unhandled address type: %T", addr)
		}
	}

	// Any remaining entries in newAddresses are new addresses that need to
	// be added to the database for the first time.
	for addrType, addrList := range newAddresses {
		for position, addr := range addrList {
			err := db.InsertNodeAddress(
				ctx, sqlc.InsertNodeAddressParams{
					NodeID:   nodeID,
					Type:     int16(addrType),
					Address:  addr,
					Position: int32(position),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert "+
					"node(%d) address(%v): %w", nodeID,
					addr, err)
			}
		}
	}

	return nil
}
1122

1123
// getNodeAddresses fetches the addresses for a node with the given public key.
1124
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
1125
        []net.Addr, error) {
×
1126

×
1127
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
1128
        // are returned in the same order as they were inserted.
×
1129
        rows, err := db.GetNodeAddressesByPubKey(
×
1130
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
1131
                        Version: int16(ProtocolV1),
×
1132
                        PubKey:  nodePub,
×
1133
                },
×
1134
        )
×
1135
        if err != nil {
×
1136
                return false, nil, err
×
1137
        }
×
1138

1139
        // GetNodeAddressesByPubKey uses a left join so there should always be
1140
        // at least one row returned if the node exists even if it has no
1141
        // addresses.
1142
        if len(rows) == 0 {
×
1143
                return false, nil, nil
×
1144
        }
×
1145

1146
        addresses := make([]net.Addr, 0, len(rows))
×
1147
        for _, addr := range rows {
×
1148
                if !(addr.Type.Valid && addr.Address.Valid) {
×
1149
                        continue
×
1150
                }
1151

1152
                address := addr.Address.String
×
1153

×
1154
                switch dbAddressType(addr.Type.Int16) {
×
1155
                case addressTypeIPv4:
×
1156
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
1157
                        if err != nil {
×
1158
                                return false, nil, nil
×
1159
                        }
×
1160
                        tcp.IP = tcp.IP.To4()
×
1161

×
1162
                        addresses = append(addresses, tcp)
×
1163

1164
                case addressTypeIPv6:
×
1165
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
1166
                        if err != nil {
×
1167
                                return false, nil, nil
×
1168
                        }
×
1169
                        addresses = append(addresses, tcp)
×
1170

1171
                case addressTypeTorV3, addressTypeTorV2:
×
1172
                        service, portStr, err := net.SplitHostPort(address)
×
1173
                        if err != nil {
×
1174
                                return false, nil, fmt.Errorf("unable to "+
×
1175
                                        "split tor v3 address: %v",
×
1176
                                        addr.Address)
×
1177
                        }
×
1178

1179
                        port, err := strconv.Atoi(portStr)
×
1180
                        if err != nil {
×
1181
                                return false, nil, err
×
1182
                        }
×
1183

1184
                        addresses = append(addresses, &tor.OnionAddr{
×
1185
                                OnionService: service,
×
1186
                                Port:         port,
×
1187
                        })
×
1188

1189
                case addressTypeOpaque:
×
1190
                        opaque, err := hex.DecodeString(address)
×
1191
                        if err != nil {
×
1192
                                return false, nil, fmt.Errorf("unable to "+
×
1193
                                        "decode opaque address: %v", addr)
×
1194
                        }
×
1195

1196
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
1197
                                Payload: opaque,
×
1198
                        })
×
1199

1200
                default:
×
1201
                        return false, nil, fmt.Errorf("unknown address "+
×
1202
                                "type: %v", addr.Type)
×
1203
                }
1204
        }
1205

1206
        return true, addresses, nil
×
1207
}
1208

1209
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
	nodeID int64, extraFields map[uint64][]byte) error {

	// Get any existing extra signed fields for the node.
	existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
	if err != nil {
		return err
	}

	// Make a lookup map of the existing field types so that we can use it
	// to keep track of any fields we should delete.
	m := make(map[uint64]bool)
	for _, field := range existingFields {
		m[uint64(field.Type)] = true
	}

	// For all the new fields, we'll upsert them and remove them from the
	// map of existing fields. The upsert covers both brand-new types and
	// value changes for types already stored.
	for tlvType, value := range extraFields {
		err = db.UpsertNodeExtraType(
			ctx, sqlc.UpsertNodeExtraTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
				Value:  value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to upsert node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}

		// Remove the field from the map of existing fields if it was
		// present.
		delete(m, tlvType)
	}

	// For all the fields that are left in the map of existing fields, we'll
	// delete them as they are no longer present in the new set of fields.
	for tlvType := range m {
		err = db.DeleteExtraNodeType(
			ctx, sqlc.DeleteExtraNodeTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}
	}

	return nil
}
1265

1266
// getSourceNode returns the DB node ID and pub key of the source node for the
1267
// specified protocol version.
1268
func getSourceNode(ctx context.Context, db SQLQueries,
1269
        version ProtocolVersion) (int64, route.Vertex, error) {
×
1270

×
1271
        var pubKey route.Vertex
×
1272

×
1273
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
1274
        if err != nil {
×
1275
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
1276
                        err)
×
1277
        }
×
1278

1279
        if len(nodes) == 0 {
×
1280
                return 0, pubKey, ErrSourceNodeNotSet
×
1281
        } else if len(nodes) > 1 {
×
1282
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
1283
                        "protocol %s found", version)
×
1284
        }
×
1285

1286
        copy(pubKey[:], nodes[0].PubKey)
×
1287

×
1288
        return nodes[0].NodeID, pubKey, nil
×
1289
}
1290

1291
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
1292
// This then produces a map from TLV type to value. If the input is not a
1293
// valid TLV stream, then an error is returned.
1294
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
1295
        r := bytes.NewReader(data)
×
1296

×
1297
        tlvStream, err := tlv.NewStream()
×
1298
        if err != nil {
×
1299
                return nil, err
×
1300
        }
×
1301

1302
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
1303
        // pass it into the P2P decoding variant.
1304
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
1305
        if err != nil {
×
1306
                return nil, err
×
1307
        }
×
1308
        if len(parsedTypes) == 0 {
×
1309
                return nil, nil
×
1310
        }
×
1311

1312
        records := make(map[uint64][]byte)
×
1313
        for k, v := range parsedTypes {
×
1314
                records[uint64(k)] = v
×
1315
        }
×
1316

1317
        return records, nil
×
1318
}
1319

1320
// insertChannel inserts a new channel record into the database.
// Along with the channel row itself it persists the channel's feature bits
// and any extra TLV fields from the announcement, and guarantees that a
// (possibly shell) node row exists for both endpoints. Returns
// ErrEdgeAlreadyExist if the SCID is already present for this protocol
// version.
func insertChannel(ctx context.Context, db SQLQueries,
	edge *models.ChannelEdgeInfo) error {

	// Serialise the SCID using the package's byte order so lookups use the
	// same encoding as the stored key.
	var chanIDB [8]byte
	byteOrder.PutUint64(chanIDB[:], edge.ChannelID)

	// Make sure that the channel doesn't already exist. We do this
	// explicitly instead of relying on catching a unique constraint error
	// because relying on SQL to throw that error would abort the entire
	// batch of transactions.
	_, err := db.GetChannelBySCID(
		ctx, sqlc.GetChannelBySCIDParams{
			Scid:    chanIDB[:],
			Version: int16(ProtocolV1),
		},
	)
	if err == nil {
		return ErrEdgeAlreadyExist
	} else if !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("unable to fetch channel: %w", err)
	}

	// Make sure that at least a "shell" entry for each node is present in
	// the nodes table.
	node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	// A zero capacity is stored as NULL rather than 0.
	var capacity sql.NullInt64
	if edge.Capacity != 0 {
		capacity = sqldb.SQLInt64(int64(edge.Capacity))
	}

	createParams := sqlc.CreateChannelParams{
		Version:     int16(ProtocolV1),
		Scid:        chanIDB[:],
		NodeID1:     node1DBID,
		NodeID2:     node2DBID,
		Outpoint:    edge.ChannelPoint.String(),
		Capacity:    capacity,
		BitcoinKey1: edge.BitcoinKey1Bytes[:],
		BitcoinKey2: edge.BitcoinKey2Bytes[:],
	}

	// The announcement signatures are only available once the channel has
	// been proven; a nil AuthProof leaves the signature columns unset.
	if edge.AuthProof != nil {
		proof := edge.AuthProof

		createParams.Node1Signature = proof.NodeSig1Bytes
		createParams.Node2Signature = proof.NodeSig2Bytes
		createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
		createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
	}

	// Insert the new channel record.
	dbChanID, err := db.CreateChannel(ctx, createParams)
	if err != nil {
		return err
	}

	// Insert any channel features.
	if len(edge.Features) != 0 {
		chanFeatures := lnwire.NewRawFeatureVector()
		err := chanFeatures.Decode(bytes.NewReader(edge.Features))
		if err != nil {
			return err
		}

		fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
		for feature := range fv.Features() {
			err = db.InsertChannelFeature(
				ctx, sqlc.InsertChannelFeatureParams{
					ChannelID:  dbChanID,
					FeatureBit: int32(feature),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert "+
					"channel(%d) feature(%v): %w", dbChanID,
					feature, err)
			}
		}
	}

	// Finally, insert any extra TLV fields in the channel announcement.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	for tlvType, value := range extra {
		err := db.CreateChannelExtraType(
			ctx, sqlc.CreateChannelExtraTypeParams{
				ChannelID: dbChanID,
				Type:      int64(tlvType),
				Value:     value,
			},
		)
		if err != nil {
			// NOTE(review): this message reports the SCID
			// (edge.ChannelID) while the feature-insert error above
			// reports the DB row ID — confirm which is intended.
			return fmt.Errorf("unable to upsert channel(%d) extra "+
				"signed field(%v): %w", edge.ChannelID,
				tlvType, err)
		}
	}

	return nil
}
1434

1435
// maybeCreateShellNode checks if a shell node entry exists for the
1436
// given public key. If it does not exist, then a new shell node entry is
1437
// created. The ID of the node is returned. A shell node only has a protocol
1438
// version and public key persisted.
1439
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
1440
        pubKey route.Vertex) (int64, error) {
×
1441

×
1442
        dbNode, err := db.GetNodeByPubKey(
×
1443
                ctx, sqlc.GetNodeByPubKeyParams{
×
1444
                        PubKey:  pubKey[:],
×
1445
                        Version: int16(ProtocolV1),
×
1446
                },
×
1447
        )
×
1448
        // The node exists. Return the ID.
×
1449
        if err == nil {
×
1450
                return dbNode.ID, nil
×
1451
        } else if !errors.Is(err, sql.ErrNoRows) {
×
1452
                return 0, err
×
1453
        }
×
1454

1455
        // Otherwise, the node does not exist, so we create a shell entry for
1456
        // it.
1457
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
1458
                Version: int16(ProtocolV1),
×
1459
                PubKey:  pubKey[:],
×
1460
        })
×
1461
        if err != nil {
×
1462
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
1463
        }
×
1464

1465
        return id, nil
×
1466
}
1467

1468
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
1469
// the database. This includes updating any existing types, inserting any new
1470
// types, and deleting any types that are no longer present.
1471
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
NEW
1472
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
NEW
1473

×
NEW
1474
        // Get any existing extra signed fields for the channel policy.
×
NEW
1475
        existingFields, err := db.GetChannelPolicyExtraTypes(
×
NEW
1476
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
NEW
1477
                        ID: chanPolicyID,
×
NEW
1478
                },
×
NEW
1479
        )
×
NEW
1480
        if err != nil {
×
NEW
1481
                return err
×
NEW
1482
        }
×
1483

1484
        // Make a lookup map of the existing field types so that we can use it
1485
        // to keep track of any fields we should delete.
NEW
1486
        m := make(map[uint64]bool)
×
NEW
1487
        for _, field := range existingFields {
×
NEW
1488
                if field.PolicyID != chanPolicyID {
×
NEW
1489
                        return fmt.Errorf("channel policy ID mismatch: "+
×
NEW
1490
                                "expected %d, got %d", chanPolicyID,
×
NEW
1491
                                field.PolicyID)
×
NEW
1492
                }
×
1493

NEW
1494
                m[uint64(field.Type)] = true
×
1495
        }
1496

1497
        // For all the new fields, we'll upsert them and remove them from the
1498
        // map of existing fields.
NEW
1499
        for tlvType, value := range extraFields {
×
NEW
1500
                err = db.UpsertChanPolicyExtraType(
×
NEW
1501
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
NEW
1502
                                ChannelPolicyID: chanPolicyID,
×
NEW
1503
                                Type:            int64(tlvType),
×
NEW
1504
                                Value:           value,
×
NEW
1505
                        },
×
NEW
1506
                )
×
NEW
1507
                if err != nil {
×
NEW
1508
                        return fmt.Errorf("unable to upsert "+
×
NEW
1509
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
1510
                                chanPolicyID, tlvType, err)
×
NEW
1511
                }
×
1512

1513
                // Remove the field from the map of existing fields if it was
1514
                // present.
NEW
1515
                delete(m, tlvType)
×
1516
        }
1517

1518
        // For all the fields that are left in the map of existing fields, we'll
1519
        // delete them as they are no longer present in the new set of fields.
NEW
1520
        for tlvType := range m {
×
NEW
1521
                err = db.DeleteChannelPolicyExtraType(
×
NEW
1522
                        ctx, sqlc.DeleteChannelPolicyExtraTypeParams{
×
NEW
1523
                                ChannelPolicyID: chanPolicyID,
×
NEW
1524
                                Type:            int64(tlvType),
×
NEW
1525
                        },
×
NEW
1526
                )
×
NEW
1527
                if err != nil {
×
NEW
1528
                        return fmt.Errorf("unable to delete "+
×
NEW
1529
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
1530
                                chanPolicyID, tlvType, err)
×
NEW
1531
                }
×
1532
        }
1533

NEW
1534
        return nil
×
1535
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc