lightningnetwork / lnd / build 15706382757

17 Jun 2025 11:48AM UTC coverage: 58.354% (-10.2%) from 68.507%

Pull Request #9956: multi: add `context.Context` param to some `graphdb.V1Store` methods
Merge eb9269fcd into a5c4a7c54

47 of 59 new or added lines in 13 files covered (79.66%).
28405 existing lines in 455 files are now uncovered.
97819 of 167629 relevant lines covered (58.35%).
1.81 hits per line.
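
To illustrate the change under test: the `graphdb.V1Store` methods touched by this pull request now take a context.Context as their first argument, as the signatures in the source file below show. The following is a minimal, hypothetical call-site sketch; the helper name lookupNode and the assumption that an initialized *graphdb.SQLStore is already available are illustrative only and not part of the PR.

package graphexample

import (
        "context"
        "time"

        graphdb "github.com/lightningnetwork/lnd/graph/db"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/routing/route"
)

// lookupNode is a hypothetical call site showing the context-first signature
// of SQLStore.FetchLightningNode: the caller now owns cancellation and
// deadlines instead of the store creating its own context internally.
func lookupNode(store *graphdb.SQLStore,
        pubKey route.Vertex) (*models.LightningNode, error) {

        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()

        return store.FetchLightningNode(ctx, pubKey)
}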

Source File: /graph/db/sql_store.go (0.0% line coverage)
package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "math"
        "net"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)

        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented, things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests.
type SQLStore struct {
        db BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        // Temporary fall-back to the KVStore so that we can implement the
        // interface incrementally.
        *KVStore
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                db:          db,
                KVStore:     kvStore,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}

// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                known, addresses, err = getNodeAddresses(
                        ctx, db, nodePub.SerializeCompressed(),
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
        var (
                ctx   = context.TODO()
                alias string
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
        ctx := context.TODO()

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
        ctx := context.TODO()

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                for _, dbNode := range dbNodes {
                        node, err := buildNode(ctx, db, &dbNode)
                        if err != nil {
                                return fmt.Errorf("unable to build node: %w",
                                        err)
                        }

                        nodes = append(nodes, *node)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
        opts ...batch.SchedulerOption) error {

        ctx := context.TODO()

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        err := insertChannel(ctx, tx, edge)

                        // Silence ErrEdgeAlreadyExist so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeAlreadyExist) {
                                alreadyExists = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID() (uint64, error) {
        ctx := context.TODO()

        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, *models.LightningNode, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, db, &dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
        *models.LightningNode, error) {

        if dbNode.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := &models.LightningNode{
                PubKeyBytes: pub,
                Features:    lnwire.EmptyFeatureVector(),
                LastUpdate:  time.Unix(0, 0),
        }

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.HaveNodeAnnouncement = true
        node.AuthSigBytes = dbNode.Signature
        node.Alias = dbNode.Alias.String
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

        var err error
        node.Color, err = DecodeHexColor(dbNode.Color.String)
        if err != nil {
                return nil, fmt.Errorf("unable to decode color: %w", err)
        }

        // Fetch the node's features.
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "features: %w", dbNode.ID, err)
        }

        // Fetch the node's addresses.
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "addresses: %w", dbNode.ID, err)
        }

        // Fetch the node's extra signed fields.
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "extra signed fields: %w", dbNode.ID, err)
        }

        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        if len(recs) != 0 {
                node.ExtraOpaqueData = recs
        }

        return node, nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}

// getNodeExtraSignedFields fetches the extra signed fields for a node with the
// given DB ID.
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64) (map[uint64][]byte, error) {

        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) extra "+
                        "signed fields: %w", nodeID, err)
        }

        extraFields := make(map[uint64][]byte)
        for _, field := range fields {
                extraFields[uint64(field.Type)] = field.Value
        }

        return extraFields, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.LightningNode) (int64, error) {

        params := sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStr(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Update the node's features.
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return 0, fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return 0, fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it from
        // the in-memory map. For any existing feature that does not exist in
        // the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to be
        // added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return fmt.Errorf("unhandled IP address: %v",
                                        addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return fmt.Errorf("invalid length for a tor " +
                                        "address")
                        }

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return fmt.Errorf("unhandled address type: %T", addr)
                }
        }

        // Any remaining entries in newAddresses are new addresses that need to
        // be added to the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.InsertNodeAddress(
                                ctx, sqlc.InsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}

// getNodeAddresses fetches the addresses for a node with the given public key.
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
        []net.Addr, error) {

        // GetNodeAddressesByPubKey ensures that the addresses for a given type
        // are returned in the same order as they were inserted.
        rows, err := db.GetNodeAddressesByPubKey(
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  nodePub,
                },
        )
        if err != nil {
                return false, nil, err
        }

        // GetNodeAddressesByPubKey uses a left join so there should always be
        // at least one row returned if the node exists even if it has no
        // addresses.
        if len(rows) == 0 {
                return false, nil, nil
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, addr := range rows {
                if !(addr.Type.Valid && addr.Address.Valid) {
                        continue
                }

                address := addr.Address.String

                switch dbAddressType(addr.Type.Int16) {
                case addressTypeIPv4:
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
                        if err != nil {
                                return false, nil, nil
                        }
                        tcp.IP = tcp.IP.To4()

                        addresses = append(addresses, tcp)

                case addressTypeIPv6:
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
                        if err != nil {
                                return false, nil, nil
                        }
                        addresses = append(addresses, tcp)

                case addressTypeTorV3, addressTypeTorV2:
                        service, portStr, err := net.SplitHostPort(address)
                        if err != nil {
                                return false, nil, fmt.Errorf("unable to "+
                                        "split tor v3 address: %v",
                                        addr.Address)
                        }

                        port, err := strconv.Atoi(portStr)
                        if err != nil {
                                return false, nil, err
                        }

                        addresses = append(addresses, &tor.OnionAddr{
                                OnionService: service,
                                Port:         port,
                        })

                case addressTypeOpaque:
                        opaque, err := hex.DecodeString(address)
                        if err != nil {
                                return false, nil, fmt.Errorf("unable to "+
                                        "decode opaque address: %v", addr)
                        }

                        addresses = append(addresses, &lnwire.OpaqueAddrs{
                                Payload: opaque,
                        })

                default:
                        return false, nil, fmt.Errorf("unknown address "+
                                "type: %v", addr.Type)
                }
        }

        return true, addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func getSourceNode(ctx context.Context, db SQLQueries,
        version ProtocolVersion) (int64, route.Vertex, error) {

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        return nodes[0].NodeID, pubKey, nil
}

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV stream.
// This then produces a map from TLV type to value. If the input is not a
// valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, err
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}

// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) error {

        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)

        // Make sure that the channel doesn't already exist. We do this
        // explicitly instead of relying on catching a unique constraint error
        // because relying on SQL to throw that error would abort the entire
        // batch of transactions.
        _, err := db.GetChannelBySCID(
                ctx, sqlc.GetChannelBySCIDParams{
                        Scid:    chanIDB[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err == nil {
                return ErrEdgeAlreadyExist
        } else if !errors.Is(err, sql.ErrNoRows) {
                return fmt.Errorf("unable to fetch channel: %w", err)
        }

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.CreateChannelParams{
                Version:     int16(ProtocolV1),
                Scid:        chanIDB[:],
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.CreateChannel(ctx, createParams)
        if err != nil {
                return err
        }

        // Insert any channel features.
        if len(edge.Features) != 0 {
                chanFeatures := lnwire.NewRawFeatureVector()
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
                if err != nil {
                        return err
                }

                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
                for feature := range fv.Features() {
                        err = db.InsertChannelFeature(
                                ctx, sqlc.InsertChannelFeatureParams{
                                        ChannelID:  dbChanID,
                                        FeatureBit: int32(feature),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "channel(%d) feature(%v): %w", dbChanID,
                                        feature, err)
                        }
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        for tlvType, value := range extra {
                err := db.CreateChannelExtraType(
                        ctx, sqlc.CreateChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
                                "signed field(%v): %w", edge.ChannelID,
                                tlvType, err)
                }
        }

        return nil
}

// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        PubKey:  pubKey[:],
                        Version: int16(ProtocolV1),
                },
        )
        // The node exists. Return the ID.
        if err == nil {
                return dbNode.ID, nil
        } else if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Otherwise, the node does not exist, so we create a shell entry for
        // it.
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  pubKey[:],
        })
        if err != nil {
                return 0, fmt.Errorf("unable to create shell node: %w", err)
        }

        return id, nil
}
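
A note on the mixed state visible in the file above: some methods (for example SourceNode, LookupAlias and HighestChanID) still create their own context via context.TODO(), while the methods updated by this PR thread the caller's context into the underlying DB transaction. The following is a minimal, hypothetical sketch contrasting the two call styles; the helper name sourceNodeSeen is not part of the codebase.

package graphexample

import (
        "context"

        graphdb "github.com/lightningnetwork/lnd/graph/db"
)

// sourceNodeSeen is a hypothetical helper contrasting the two call styles in
// sql_store.go.
func sourceNodeSeen(ctx context.Context,
        store *graphdb.SQLStore) (bool, error) {

        // No context parameter yet: SourceNode falls back to context.TODO()
        // internally.
        src, err := store.SourceNode()
        if err != nil {
                return false, err
        }

        // Context-aware method updated by this PR: the caller's ctx flows
        // into the DB transaction, so cancellation and deadlines apply.
        _, exists, err := store.HasLightningNode(ctx, src.PubKeyBytes)
        if err != nil {
                return false, err
        }

        return exists, nil
}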