• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16448527814

22 Jul 2025 03:19PM UTC coverage: 67.239% (+9.7%) from 57.535%
16448527814

Pull #10081

github

web-flow
Merge ddc0e95ed into f09c7aee4
Pull Request #10081: graph/db: use `/*SLICE:<field_name>*/` to optimise various graph queries

20 of 471 new or added lines in 4 files covered. (4.25%)

39 existing lines in 9 files now uncovered.

135503 of 201523 relevant lines covered (67.24%)

21726.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// value is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion enumerates the gossip protocol versions a message can
// belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.Itoa(int(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannels(ctx context.Context, ids []int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115

116
        /*
117
                Channel Policy table queries.
118
        */
119
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
120
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
121
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
122

123
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
124
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
125
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
126

127
        /*
128
                Zombie index queries.
129
        */
130
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
131
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
132
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
133
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
134
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
135

136
        /*
137
                Prune log table queries.
138
        */
139
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
140
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
141
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
142
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
143

144
        /*
145
                Closed SCID table queries.
146
        */
147
        InsertClosedChannel(ctx context.Context, scid []byte) error
148
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
149
}
150

151
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
152
// database operations.
153
type BatchedSQLQueries interface {
154
        SQLQueries
155
        sqldb.BatchedTx[SQLQueries]
156
}
157

158
// SQLStore is an implementation of the V1Store interface that uses a SQL
159
// database as the backend.
160
type SQLStore struct {
161
        cfg *SQLStoreConfig
162
        db  BatchedSQLQueries
163

164
        // cacheMu guards all caches (rejectCache and chanCache). If
165
        // this mutex will be acquired at the same time as the DB mutex then
166
        // the cacheMu MUST be acquired first to prevent deadlock.
167
        cacheMu     sync.RWMutex
168
        rejectCache *rejectCache
169
        chanCache   *channelCache
170

171
        chanScheduler batch.Scheduler[SQLQueries]
172
        nodeScheduler batch.Scheduler[SQLQueries]
173

174
        srcNodes  map[ProtocolVersion]*srcNodeInfo
175
        srcNodeMu sync.Mutex
176
}
177

178
// A compile-time assertion to ensure that SQLStore implements the V1Store
179
// interface.
180
var _ V1Store = (*SQLStore)(nil)
181

182
// SQLStoreConfig holds the configuration for the SQLStore.
183
type SQLStoreConfig struct {
184
        // ChainHash is the genesis hash for the chain that all the gossip
185
        // messages in this store are aimed at.
186
        ChainHash chainhash.Hash
187

188
        // PaginationCfg is the configuration for paginated queries.
189
        PaginationCfg *sqldb.PagedQueryConfig
190
}
191

192
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
193
// storage backend.
194
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
195
        options ...StoreOptionModifier) (*SQLStore, error) {
×
196

×
197
        opts := DefaultOptions()
×
198
        for _, o := range options {
×
199
                o(opts)
×
200
        }
×
201

202
        if opts.NoMigration {
×
203
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
204
                        "supported for SQL stores")
×
205
        }
×
206

207
        s := &SQLStore{
×
208
                cfg:         cfg,
×
209
                db:          db,
×
210
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
211
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
212
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
213
        }
×
214

×
215
        s.chanScheduler = batch.NewTimeScheduler(
×
216
                db, &s.cacheMu, opts.BatchCommitInterval,
×
217
        )
×
218
        s.nodeScheduler = batch.NewTimeScheduler(
×
219
                db, nil, opts.BatchCommitInterval,
×
220
        )
×
221

×
222
        return s, nil
×
223
}
224

225
// AddLightningNode adds a vertex/node to the graph database. If the node is not
226
// in the database from before, this will add a new, unconnected one to the
227
// graph. If it is present from before, this will update that node's
228
// information.
229
//
230
// NOTE: part of the V1Store interface.
231
func (s *SQLStore) AddLightningNode(ctx context.Context,
232
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
233

×
234
        r := &batch.Request[SQLQueries]{
×
235
                Opts: batch.NewSchedulerOptions(opts...),
×
236
                Do: func(queries SQLQueries) error {
×
237
                        _, err := upsertNode(ctx, queries, node)
×
238
                        return err
×
239
                },
×
240
        }
241

242
        return s.nodeScheduler.Execute(ctx, r)
×
243
}
244

245
// FetchLightningNode attempts to look up a target node by its identity public
246
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
247
// returned.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) FetchLightningNode(ctx context.Context,
251
        pubKey route.Vertex) (*models.LightningNode, error) {
×
252

×
253
        var node *models.LightningNode
×
254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
255
                var err error
×
256
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
257

×
258
                return err
×
259
        }, sqldb.NoOpReset)
×
260
        if err != nil {
×
261
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
262
        }
×
263

264
        return node, nil
×
265
}
266

267
// HasLightningNode determines if the graph has a vertex identified by the
268
// target node identity public key. If the node exists in the database, a
269
// timestamp of when the data for the node was lasted updated is returned along
270
// with a true boolean. Otherwise, an empty time.Time is returned with a false
271
// boolean.
272
//
273
// NOTE: part of the V1Store interface.
274
func (s *SQLStore) HasLightningNode(ctx context.Context,
275
        pubKey [33]byte) (time.Time, bool, error) {
×
276

×
277
        var (
×
278
                exists     bool
×
279
                lastUpdate time.Time
×
280
        )
×
281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
282
                dbNode, err := db.GetNodeByPubKey(
×
283
                        ctx, sqlc.GetNodeByPubKeyParams{
×
284
                                Version: int16(ProtocolV1),
×
285
                                PubKey:  pubKey[:],
×
286
                        },
×
287
                )
×
288
                if errors.Is(err, sql.ErrNoRows) {
×
289
                        return nil
×
290
                } else if err != nil {
×
291
                        return fmt.Errorf("unable to fetch node: %w", err)
×
292
                }
×
293

294
                exists = true
×
295

×
296
                if dbNode.LastUpdate.Valid {
×
297
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return time.Time{}, false,
×
304
                        fmt.Errorf("unable to fetch node: %w", err)
×
305
        }
×
306

307
        return lastUpdate, exists, nil
×
308
}
309

310
// AddrsForNode returns all known addresses for the target node public key
311
// that the graph DB is aware of. The returned boolean indicates if the
312
// given node is unknown to the graph DB or not.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) AddrsForNode(ctx context.Context,
316
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
317

×
318
        var (
×
319
                addresses []net.Addr
×
320
                known     bool
×
321
        )
×
322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
323
                var err error
×
324
                known, addresses, err = getNodeAddresses(
×
325
                        ctx, db, nodePub.SerializeCompressed(),
×
326
                )
×
327
                if err != nil {
×
328
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
329
                                err)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
336
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
337
        }
×
338

339
        return known, addresses, nil
×
340
}
341

342
// DeleteLightningNode starts a new database transaction to remove a vertex/node
343
// from the database according to the node's public key.
344
//
345
// NOTE: part of the V1Store interface.
346
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
347
        pubKey route.Vertex) error {
×
348

×
349
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
350
                res, err := db.DeleteNodeByPubKey(
×
351
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
352
                                Version: int16(ProtocolV1),
×
353
                                PubKey:  pubKey[:],
×
354
                        },
×
355
                )
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                rows, err := res.RowsAffected()
×
361
                if err != nil {
×
362
                        return err
×
363
                }
×
364

365
                if rows == 0 {
×
366
                        return ErrGraphNodeNotFound
×
367
                } else if rows > 1 {
×
368
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
369
                }
×
370

371
                return err
×
372
        }, sqldb.NoOpReset)
373
        if err != nil {
×
374
                return fmt.Errorf("unable to delete node: %w", err)
×
375
        }
×
376

377
        return nil
×
378
}
379

380
// FetchNodeFeatures returns the features of the given node. If no features are
381
// known for the node, an empty feature vector is returned.
382
//
383
// NOTE: this is part of the graphdb.NodeTraverser interface.
384
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
385
        *lnwire.FeatureVector, error) {
×
386

×
387
        ctx := context.TODO()
×
388

×
389
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
390
}
×
391

392
// DisabledChannelIDs returns the channel ids of disabled channels.
393
// A channel is disabled when two of the associated ChanelEdgePolicies
394
// have their disabled bit on.
395
//
396
// NOTE: part of the V1Store interface.
397
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
398
        var (
×
399
                ctx     = context.TODO()
×
400
                chanIDs []uint64
×
401
        )
×
402
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
403
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
404
                if err != nil {
×
405
                        return fmt.Errorf("unable to fetch disabled "+
×
406
                                "channels: %w", err)
×
407
                }
×
408

409
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
410

×
411
                return nil
×
412
        }, sqldb.NoOpReset)
413
        if err != nil {
×
414
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
415
                        err)
×
416
        }
×
417

418
        return chanIDs, nil
×
419
}
420

421
// LookupAlias attempts to return the alias as advertised by the target node.
422
//
423
// NOTE: part of the V1Store interface.
424
func (s *SQLStore) LookupAlias(ctx context.Context,
425
        pub *btcec.PublicKey) (string, error) {
×
426

×
427
        var alias string
×
428
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
429
                dbNode, err := db.GetNodeByPubKey(
×
430
                        ctx, sqlc.GetNodeByPubKeyParams{
×
431
                                Version: int16(ProtocolV1),
×
432
                                PubKey:  pub.SerializeCompressed(),
×
433
                        },
×
434
                )
×
435
                if errors.Is(err, sql.ErrNoRows) {
×
436
                        return ErrNodeAliasNotFound
×
437
                } else if err != nil {
×
438
                        return fmt.Errorf("unable to fetch node: %w", err)
×
439
                }
×
440

441
                if !dbNode.Alias.Valid {
×
442
                        return ErrNodeAliasNotFound
×
443
                }
×
444

445
                alias = dbNode.Alias.String
×
446

×
447
                return nil
×
448
        }, sqldb.NoOpReset)
449
        if err != nil {
×
450
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
451
        }
×
452

453
        return alias, nil
×
454
}
455

456
// SourceNode returns the source node of the graph. The source node is treated
457
// as the center node within a star-graph. This method may be used to kick off
458
// a path finding algorithm in order to explore the reachability of another
459
// node based off the source node.
460
//
461
// NOTE: part of the V1Store interface.
462
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
463
        error) {
×
464

×
465
        var node *models.LightningNode
×
466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
467
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
468
                if err != nil {
×
469
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
470
                                err)
×
471
                }
×
472

473
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
474

×
475
                return err
×
476
        }, sqldb.NoOpReset)
477
        if err != nil {
×
478
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
479
        }
×
480

481
        return node, nil
×
482
}
483

484
// SetSourceNode sets the source node within the graph database. The source
485
// node is to be used as the center of a star-graph within path finding
486
// algorithms.
487
//
488
// NOTE: part of the V1Store interface.
489
func (s *SQLStore) SetSourceNode(ctx context.Context,
490
        node *models.LightningNode) error {
×
491

×
492
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
493
                id, err := upsertNode(ctx, db, node)
×
494
                if err != nil {
×
495
                        return fmt.Errorf("unable to upsert source node: %w",
×
496
                                err)
×
497
                }
×
498

499
                // Make sure that if a source node for this version is already
500
                // set, then the ID is the same as the one we are about to set.
501
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
502
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
503
                        return fmt.Errorf("unable to fetch source node: %w",
×
504
                                err)
×
505
                } else if err == nil {
×
506
                        if dbSourceNodeID != id {
×
507
                                return fmt.Errorf("v1 source node already "+
×
508
                                        "set to a different node: %d vs %d",
×
509
                                        dbSourceNodeID, id)
×
510
                        }
×
511

512
                        return nil
×
513
                }
514

515
                return db.AddSourceNode(ctx, id)
×
516
        }, sqldb.NoOpReset)
517
}
518

519
// NodeUpdatesInHorizon returns all the known lightning node which have an
520
// update timestamp within the passed range. This method can be used by two
521
// nodes to quickly determine if they have the same set of up to date node
522
// announcements.
523
//
524
// NOTE: This is part of the V1Store interface.
525
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
526
        endTime time.Time) ([]models.LightningNode, error) {
×
527

×
528
        ctx := context.TODO()
×
529

×
530
        var nodes []models.LightningNode
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
533
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
534
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
535
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
536
                        },
×
537
                )
×
538
                if err != nil {
×
539
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
540
                }
×
541

542
                for _, dbNode := range dbNodes {
×
543
                        node, err := buildNode(ctx, db, &dbNode)
×
544
                        if err != nil {
×
545
                                return fmt.Errorf("unable to build node: %w",
×
546
                                        err)
×
547
                        }
×
548

549
                        nodes = append(nodes, *node)
×
550
                }
551

552
                return nil
×
553
        }, sqldb.NoOpReset)
554
        if err != nil {
×
555
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
556
        }
×
557

558
        return nodes, nil
×
559
}
560

561
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
562
// undirected edge from the two target nodes are created. The information stored
563
// denotes the static attributes of the channel, such as the channelID, the keys
564
// involved in creation of the channel, and the set of features that the channel
565
// supports. The chanPoint and chanID are used to uniquely identify the edge
566
// globally within the database.
567
//
568
// NOTE: part of the V1Store interface.
569
func (s *SQLStore) AddChannelEdge(ctx context.Context,
570
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
571

×
572
        var alreadyExists bool
×
573
        r := &batch.Request[SQLQueries]{
×
574
                Opts: batch.NewSchedulerOptions(opts...),
×
575
                Reset: func() {
×
576
                        alreadyExists = false
×
577
                },
×
578
                Do: func(tx SQLQueries) error {
×
579
                        _, err := insertChannel(ctx, tx, edge)
×
580

×
581
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
582
                        // succeed, but propagate the error via local state.
×
583
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
584
                                alreadyExists = true
×
585
                                return nil
×
586
                        }
×
587

588
                        return err
×
589
                },
590
                OnCommit: func(err error) error {
×
591
                        switch {
×
592
                        case err != nil:
×
593
                                return err
×
594
                        case alreadyExists:
×
595
                                return ErrEdgeAlreadyExist
×
596
                        default:
×
597
                                s.rejectCache.remove(edge.ChannelID)
×
598
                                s.chanCache.remove(edge.ChannelID)
×
599
                                return nil
×
600
                        }
601
                },
602
        }
603

604
        return s.chanScheduler.Execute(ctx, r)
×
605
}
606

607
// HighestChanID returns the "highest" known channel ID in the channel graph.
608
// This represents the "newest" channel from the PoV of the chain. This method
609
// can be used by peers to quickly determine if their graphs are in sync.
610
//
611
// NOTE: This is part of the V1Store interface.
612
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
613
        var highestChanID uint64
×
614
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
616
                if errors.Is(err, sql.ErrNoRows) {
×
617
                        return nil
×
618
                } else if err != nil {
×
619
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
620
                                err)
×
621
                }
×
622

623
                highestChanID = byteOrder.Uint64(chanID)
×
624

×
625
                return nil
×
626
        }, sqldb.NoOpReset)
627
        if err != nil {
×
628
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
629
        }
×
630

631
        return highestChanID, nil
×
632
}
633

634
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
635
// within the database for the referenced channel. The `flags` attribute within
636
// the ChannelEdgePolicy determines which of the directed edges are being
637
// updated. If the flag is 1, then the first node's information is being
638
// updated, otherwise it's the second node's information. The node ordering is
639
// determined by the lexicographical ordering of the identity public keys of the
640
// nodes on either side of the channel.
641
//
642
// NOTE: part of the V1Store interface.
643
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
644
        edge *models.ChannelEdgePolicy,
645
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
646

×
647
        var (
×
648
                isUpdate1    bool
×
649
                edgeNotFound bool
×
650
                from, to     route.Vertex
×
651
        )
×
652

×
653
        r := &batch.Request[SQLQueries]{
×
654
                Opts: batch.NewSchedulerOptions(opts...),
×
655
                Reset: func() {
×
656
                        isUpdate1 = false
×
657
                        edgeNotFound = false
×
658
                },
×
659
                Do: func(tx SQLQueries) error {
×
660
                        var err error
×
661
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
662
                                ctx, tx, edge,
×
663
                        )
×
664
                        if err != nil {
×
665
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
666
                        }
×
667

668
                        // Silence ErrEdgeNotFound so that the batch can
669
                        // succeed, but propagate the error via local state.
670
                        if errors.Is(err, ErrEdgeNotFound) {
×
671
                                edgeNotFound = true
×
672
                                return nil
×
673
                        }
×
674

675
                        return err
×
676
                },
677
                OnCommit: func(err error) error {
×
678
                        switch {
×
679
                        case err != nil:
×
680
                                return err
×
681
                        case edgeNotFound:
×
682
                                return ErrEdgeNotFound
×
683
                        default:
×
684
                                s.updateEdgeCache(edge, isUpdate1)
×
685
                                return nil
×
686
                        }
687
                },
688
        }
689

690
        err := s.chanScheduler.Execute(ctx, r)
×
691

×
692
        return from, to, err
×
693
}
694

695
// updateEdgeCache updates our reject and channel caches with the new
696
// edge policy information.
697
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
698
        isUpdate1 bool) {
×
699

×
700
        // If an entry for this channel is found in reject cache, we'll modify
×
701
        // the entry with the updated timestamp for the direction that was just
×
702
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
703
        // during the next query for this edge.
×
704
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
705
                if isUpdate1 {
×
706
                        entry.upd1Time = e.LastUpdate.Unix()
×
707
                } else {
×
708
                        entry.upd2Time = e.LastUpdate.Unix()
×
709
                }
×
710
                s.rejectCache.insert(e.ChannelID, entry)
×
711
        }
712

713
        // If an entry for this channel is found in channel cache, we'll modify
714
        // the entry with the updated policy for the direction that was just
715
        // written. If the edge doesn't exist, we'll defer loading the info and
716
        // policies and lazily read from disk during the next query.
717
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
718
                if isUpdate1 {
×
719
                        channel.Policy1 = e
×
720
                } else {
×
721
                        channel.Policy2 = e
×
722
                }
×
723
                s.chanCache.insert(e.ChannelID, channel)
×
724
        }
725
}
726

727
// ForEachSourceNodeChannel iterates through all channels of the source node,
728
// executing the passed callback on each. The call-back is provided with the
729
// channel's outpoint, whether we have a policy for the channel and the channel
730
// peer's node information.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
734
        cb func(chanPoint wire.OutPoint, havePolicy bool,
735
                otherNode *models.LightningNode) error, reset func()) error {
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, reset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(ctx context.Context,
793
        cb func(tx NodeRTx) error, reset func()) error {
×
794

×
795
        var lastID int64 = 0
×
796
        handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
×
797
                node, err := buildNode(ctx, db, &dbNode)
×
798
                if err != nil {
×
799
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
800
                                dbNode.ID, err)
×
801
                }
×
802

803
                err = cb(
×
804
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
805
                )
×
806
                if err != nil {
×
807
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
808
                                dbNode.ID, err)
×
809
                }
×
810

811
                return nil
×
812
        }
813

814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
815
                for {
×
816
                        nodes, err := db.ListNodesPaginated(
×
817
                                ctx, sqlc.ListNodesPaginatedParams{
×
818
                                        Version: int16(ProtocolV1),
×
819
                                        ID:      lastID,
×
820
                                        Limit:   pageSize,
×
821
                                },
×
822
                        )
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("unable to fetch nodes: %w",
×
825
                                        err)
×
826
                        }
×
827

828
                        if len(nodes) == 0 {
×
829
                                break
×
830
                        }
831

832
                        for _, dbNode := range nodes {
×
833
                                err = handleNode(db, dbNode)
×
834
                                if err != nil {
×
835
                                        return err
×
836
                                }
×
837

838
                                lastID = dbNode.ID
×
839
                        }
840
                }
841

842
                return nil
×
843
        }, reset)
844
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error, reset func()) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, reset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
931
        cb func(route.Vertex, *lnwire.FeatureVector) error,
932
        reset func()) error {
×
933

×
934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
936
                        nodePub route.Vertex) error {
×
937

×
938
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
939
                        if err != nil {
×
940
                                return fmt.Errorf("unable to fetch node "+
×
941
                                        "features: %w", err)
×
942
                        }
×
943

944
                        return cb(nodePub, features)
×
945
                })
946
        }, reset)
947
        if err != nil {
×
948
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
949
        }
×
950

951
        return nil
×
952
}
953

954
// ForEachNodeChannel iterates through all channels of the given node,
955
// executing the passed callback with an edge info structure and the policies
956
// of each end of the channel. The first edge policy is the outgoing edge *to*
957
// the connecting node, while the second is the incoming edge *from* the
958
// connecting node. If the callback returns an error, then the iteration is
959
// halted with the error propagated back up to the caller.
960
//
961
// Unknown policies are passed into the callback as nil values.
962
//
963
// NOTE: part of the V1Store interface.
964
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
965
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
966
                *models.ChannelEdgePolicy) error, reset func()) error {
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, reset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1045
                                row.GraphChannel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1113
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1114
        reset func()) error {
×
1115

×
1116
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1117
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1118
                        nodePub route.Vertex) error {
×
1119

×
1120
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1121
                        if err != nil {
×
1122
                                return fmt.Errorf("unable to fetch "+
×
1123
                                        "node(id=%d) features: %w", nodeID, err)
×
1124
                        }
×
1125

1126
                        toNodeCallback := func() route.Vertex {
×
1127
                                return nodePub
×
1128
                        }
×
1129

1130
                        rows, err := db.ListChannelsByNodeID(
×
1131
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1132
                                        Version: int16(ProtocolV1),
×
1133
                                        NodeID1: nodeID,
×
1134
                                },
×
1135
                        )
×
1136
                        if err != nil {
×
1137
                                return fmt.Errorf("unable to fetch channels "+
×
1138
                                        "of node(id=%d): %w", nodeID, err)
×
1139
                        }
×
1140

1141
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1142
                        for _, row := range rows {
×
1143
                                node1, node2, err := buildNodeVertices(
×
1144
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1145
                                )
×
1146
                                if err != nil {
×
1147
                                        return err
×
1148
                                }
×
1149

1150
                                e, err := getAndBuildEdgeInfo(
×
1151
                                        ctx, db, s.cfg.ChainHash,
×
1152
                                        row.GraphChannel.ID, row.GraphChannel,
×
1153
                                        node1, node2,
×
1154
                                )
×
1155
                                if err != nil {
×
1156
                                        return fmt.Errorf("unable to build "+
×
1157
                                                "channel info: %w", err)
×
1158
                                }
×
1159

1160
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1161
                                        row,
×
1162
                                )
×
1163
                                if err != nil {
×
1164
                                        return fmt.Errorf("unable to "+
×
1165
                                                "extract channel "+
×
1166
                                                "policies: %w", err)
×
1167
                                }
×
1168

1169
                                p1, p2, err := getAndBuildChanPolicies(
×
1170
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1171
                                        node1, node2,
×
1172
                                )
×
1173
                                if err != nil {
×
1174
                                        return fmt.Errorf("unable to "+
×
1175
                                                "build channel policies: %w",
×
1176
                                                err)
×
1177
                                }
×
1178

1179
                                // Determine the outgoing and incoming policy
1180
                                // for this channel and node combo.
1181
                                outPolicy, inPolicy := p1, p2
×
1182
                                if p1 != nil && p1.ToNode == nodePub {
×
1183
                                        outPolicy, inPolicy = p2, p1
×
1184
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1185
                                        outPolicy, inPolicy = p2, p1
×
1186
                                }
×
1187

1188
                                var cachedInPolicy *models.CachedEdgePolicy
×
1189
                                if inPolicy != nil {
×
1190
                                        cachedInPolicy = models.NewCachedPolicy(
×
1191
                                                p2,
×
1192
                                        )
×
1193
                                        cachedInPolicy.ToNodePubKey =
×
1194
                                                toNodeCallback
×
1195
                                        cachedInPolicy.ToNodeFeatures =
×
1196
                                                features
×
1197
                                }
×
1198

1199
                                var inboundFee lnwire.Fee
×
1200
                                outPolicy.InboundFee.WhenSome(
×
1201
                                        func(fee lnwire.Fee) {
×
1202
                                                inboundFee = fee
×
1203
                                        },
×
1204
                                )
1205

1206
                                directedChannel := &DirectedChannel{
×
1207
                                        ChannelID: e.ChannelID,
×
1208
                                        IsNode1: nodePub ==
×
1209
                                                e.NodeKey1Bytes,
×
1210
                                        OtherNode:    e.NodeKey2Bytes,
×
1211
                                        Capacity:     e.Capacity,
×
1212
                                        OutPolicySet: p1 != nil,
×
1213
                                        InPolicy:     cachedInPolicy,
×
1214
                                        InboundFee:   inboundFee,
×
1215
                                }
×
1216

×
1217
                                if nodePub == e.NodeKey2Bytes {
×
1218
                                        directedChannel.OtherNode =
×
1219
                                                e.NodeKey1Bytes
×
1220
                                }
×
1221

1222
                                channels[e.ChannelID] = directedChannel
×
1223
                        }
1224

1225
                        return cb(nodePub, channels)
×
1226
                })
1227
        }, reset)
1228
}
1229

1230
// ForEachChannelCacheable iterates through all the channel edges stored
1231
// within the graph and invokes the passed callback for each edge. The
1232
// callback takes two edges as since this is a directed graph, both the
1233
// in/out edges are visited. If the callback returns an error, then the
1234
// transaction is aborted and the iteration stops early.
1235
//
1236
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1237
// pointer for that particular channel edge routing policy will be
1238
// passed into the callback.
1239
//
1240
// NOTE: this method is like ForEachChannel but fetches only the data
1241
// required for the graph cache.
1242
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1243
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1244
        reset func()) error {
×
1245

×
1246
        ctx := context.TODO()
×
1247

×
1248
        handleChannel := func(db SQLQueries,
×
1249
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1250

×
1251
                node1, node2, err := buildNodeVertices(
×
1252
                        row.Node1Pubkey, row.Node2Pubkey,
×
1253
                )
×
1254
                if err != nil {
×
1255
                        return err
×
1256
                }
×
1257

1258
                edge := buildCacheableChannelInfo(
×
1259
                        row.GraphChannel, node1, node2,
×
1260
                )
×
1261

×
1262
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1263
                if err != nil {
×
1264
                        return err
×
1265
                }
×
1266

1267
                var pol1, pol2 *models.CachedEdgePolicy
×
1268
                if dbPol1 != nil {
×
1269
                        policy1, err := buildChanPolicy(
×
1270
                                *dbPol1, edge.ChannelID, nil, node2,
×
1271
                        )
×
1272
                        if err != nil {
×
1273
                                return err
×
1274
                        }
×
1275

1276
                        pol1 = models.NewCachedPolicy(policy1)
×
1277
                }
1278
                if dbPol2 != nil {
×
1279
                        policy2, err := buildChanPolicy(
×
1280
                                *dbPol2, edge.ChannelID, nil, node1,
×
1281
                        )
×
1282
                        if err != nil {
×
1283
                                return err
×
1284
                        }
×
1285

1286
                        pol2 = models.NewCachedPolicy(policy2)
×
1287
                }
1288

1289
                if err := cb(edge, pol1, pol2); err != nil {
×
1290
                        return err
×
1291
                }
×
1292

1293
                return nil
×
1294
        }
1295

1296
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1297
                lastID := int64(-1)
×
1298
                for {
×
1299
                        //nolint:ll
×
1300
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1301
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1302
                                        Version: int16(ProtocolV1),
×
1303
                                        ID:      lastID,
×
1304
                                        Limit:   pageSize,
×
1305
                                },
×
1306
                        )
×
1307
                        if err != nil {
×
1308
                                return err
×
1309
                        }
×
1310

1311
                        if len(rows) == 0 {
×
1312
                                break
×
1313
                        }
1314

1315
                        for _, row := range rows {
×
1316
                                err := handleChannel(db, row)
×
1317
                                if err != nil {
×
1318
                                        return err
×
1319
                                }
×
1320

1321
                                lastID = row.GraphChannel.ID
×
1322
                        }
1323
                }
1324

1325
                return nil
×
1326
        }, reset)
1327
}
1328

1329
// ForEachChannel iterates through all the channel edges stored within the
1330
// graph and invokes the passed callback for each edge. The callback takes two
1331
// edges as since this is a directed graph, both the in/out edges are visited.
1332
// If the callback returns an error, then the transaction is aborted and the
1333
// iteration stops early.
1334
//
1335
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1336
// for that particular channel edge routing policy will be passed into the
1337
// callback.
1338
//
1339
// NOTE: part of the V1Store interface.
1340
func (s *SQLStore) ForEachChannel(ctx context.Context,
1341
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1342
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1357
                        row.GraphChannel, node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.GraphChannel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, reset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart,
×
1456
                                EndScid:   chanIDEnd,
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB,
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB,
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
NEW
1716
        // Keep track of which channels we end up finding so that we can
×
NEW
1717
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
NEW
1718
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
NEW
1719
        for _, chanID := range chanIDs {
×
NEW
1720
                chanLookup[chanID] = struct{}{}
×
NEW
1721
        }
×
1722

1723
        var (
×
1724
                ctx     = context.TODO()
×
1725
                deleted []*models.ChannelEdgeInfo
×
1726
        )
×
1727
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1728
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
NEW
1729
                chanCallBack := func(ctx context.Context,
×
NEW
1730
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1731

×
NEW
1732
                        // Deleting the entry from the map indicates that we
×
NEW
1733
                        // have found the channel.
×
NEW
1734
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
NEW
1735
                        delete(chanLookup, scid)
×
1736

×
1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1746
                                row.GraphChannel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        deleted = append(deleted, info)
×
NEW
1753
                        chanIDsToDelete = append(
×
NEW
1754
                                chanIDsToDelete, row.GraphChannel.ID,
×
NEW
1755
                        )
×
1756

×
1757
                        if !markZombie {
×
NEW
1758
                                return nil
×
UNCOV
1759
                        }
×
1760

1761
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1762
                                info.NodeKey2Bytes
×
1763
                        if strictZombiePruning {
×
1764
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1765
                                if row.Policy1LastUpdate.Valid {
×
1766
                                        e1Time := time.Unix(
×
1767
                                                row.Policy1LastUpdate.Int64, 0,
×
1768
                                        )
×
1769
                                        e1UpdateTime = &e1Time
×
1770
                                }
×
1771
                                if row.Policy2LastUpdate.Valid {
×
1772
                                        e2Time := time.Unix(
×
1773
                                                row.Policy2LastUpdate.Int64, 0,
×
1774
                                        )
×
1775
                                        e2UpdateTime = &e2Time
×
1776
                                }
×
1777

1778
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1779
                                        info, e1UpdateTime, e2UpdateTime,
×
1780
                                )
×
1781
                        }
1782

1783
                        err = db.UpsertZombieChannel(
×
1784
                                ctx, sqlc.UpsertZombieChannelParams{
×
1785
                                        Version:  int16(ProtocolV1),
×
NEW
1786
                                        Scid:     channelIDToBytes(scid),
×
1787
                                        NodeKey1: nodeKey1[:],
×
1788
                                        NodeKey2: nodeKey2[:],
×
1789
                                },
×
1790
                        )
×
1791
                        if err != nil {
×
1792
                                return fmt.Errorf("unable to mark channel as "+
×
1793
                                        "zombie: %w", err)
×
1794
                        }
×
1795

NEW
1796
                        return nil
×
1797
                }
1798

NEW
1799
                err := s.forEachChanWithPoliciesInSCIDList(
×
NEW
1800
                        ctx, db, chanCallBack, chanIDs,
×
NEW
1801
                )
×
NEW
1802
                if err != nil {
×
NEW
1803
                        return err
×
NEW
1804
                }
×
1805

NEW
1806
                if len(chanLookup) > 0 {
×
NEW
1807
                        return ErrEdgeNotFound
×
NEW
1808
                }
×
1809

NEW
1810
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1811
        }, func() {
×
1812
                deleted = nil
×
NEW
1813

×
NEW
1814
                // Re-fill the lookup map.
×
NEW
1815
                for _, chanID := range chanIDs {
×
NEW
1816
                        chanLookup[chanID] = struct{}{}
×
NEW
1817
                }
×
1818
        })
1819
        if err != nil {
×
1820
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1821
                        err)
×
1822
        }
×
1823

1824
        for _, chanID := range chanIDs {
×
1825
                s.rejectCache.remove(chanID)
×
1826
                s.chanCache.remove(chanID)
×
1827
        }
×
1828

1829
        return deleted, nil
×
1830
}
1831

1832
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1833
// channel identified by the channel ID. If the channel can't be found, then
1834
// ErrEdgeNotFound is returned. A struct which houses the general information
1835
// for the channel itself is returned as well as two structs that contain the
1836
// routing policies for the channel in either direction.
1837
//
1838
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1839
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1840
// the ChannelEdgeInfo will only include the public keys of each node.
1841
//
1842
// NOTE: part of the V1Store interface.
1843
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1844
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1845
        *models.ChannelEdgePolicy, error) {
×
1846

×
1847
        var (
×
1848
                ctx              = context.TODO()
×
1849
                edge             *models.ChannelEdgeInfo
×
1850
                policy1, policy2 *models.ChannelEdgePolicy
×
1851
                chanIDB          = channelIDToBytes(chanID)
×
1852
        )
×
1853
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1854
                row, err := db.GetChannelBySCIDWithPolicies(
×
1855
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1856
                                Scid:    chanIDB,
×
1857
                                Version: int16(ProtocolV1),
×
1858
                        },
×
1859
                )
×
1860
                if errors.Is(err, sql.ErrNoRows) {
×
1861
                        // First check if this edge is perhaps in the zombie
×
1862
                        // index.
×
1863
                        zombie, err := db.GetZombieChannel(
×
1864
                                ctx, sqlc.GetZombieChannelParams{
×
1865
                                        Scid:    chanIDB,
×
1866
                                        Version: int16(ProtocolV1),
×
1867
                                },
×
1868
                        )
×
1869
                        if errors.Is(err, sql.ErrNoRows) {
×
1870
                                return ErrEdgeNotFound
×
1871
                        } else if err != nil {
×
1872
                                return fmt.Errorf("unable to check if "+
×
1873
                                        "channel is zombie: %w", err)
×
1874
                        }
×
1875

1876
                        // At this point, we know the channel is a zombie, so
1877
                        // we'll return an error indicating this, and we will
1878
                        // populate the edge info with the public keys of each
1879
                        // party as this is the only information we have about
1880
                        // it.
1881
                        edge = &models.ChannelEdgeInfo{}
×
1882
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1883
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1884

×
1885
                        return ErrZombieEdge
×
1886
                } else if err != nil {
×
1887
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1888
                }
×
1889

1890
                node1, node2, err := buildNodeVertices(
×
1891
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1892
                )
×
1893
                if err != nil {
×
1894
                        return err
×
1895
                }
×
1896

1897
                edge, err = getAndBuildEdgeInfo(
×
1898
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1899
                        row.GraphChannel, node1, node2,
×
1900
                )
×
1901
                if err != nil {
×
1902
                        return fmt.Errorf("unable to build channel info: %w",
×
1903
                                err)
×
1904
                }
×
1905

1906
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1907
                if err != nil {
×
1908
                        return fmt.Errorf("unable to extract channel "+
×
1909
                                "policies: %w", err)
×
1910
                }
×
1911

1912
                policy1, policy2, err = getAndBuildChanPolicies(
×
1913
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1914
                )
×
1915
                if err != nil {
×
1916
                        return fmt.Errorf("unable to build channel "+
×
1917
                                "policies: %w", err)
×
1918
                }
×
1919

1920
                return nil
×
1921
        }, sqldb.NoOpReset)
1922
        if err != nil {
×
1923
                // If we are returning the ErrZombieEdge, then we also need to
×
1924
                // return the edge info as the method comment indicates that
×
1925
                // this will be populated when the edge is a zombie.
×
1926
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1927
                        err)
×
1928
        }
×
1929

1930
        return edge, policy1, policy2, nil
×
1931
}
1932

1933
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1934
// the channel identified by the funding outpoint. If the channel can't be
1935
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1936
// information for the channel itself is returned as well as two structs that
1937
// contain the routing policies for the channel in either direction.
1938
//
1939
// NOTE: part of the V1Store interface.
1940
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1941
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1942
        *models.ChannelEdgePolicy, error) {
×
1943

×
1944
        var (
×
1945
                ctx              = context.TODO()
×
1946
                edge             *models.ChannelEdgeInfo
×
1947
                policy1, policy2 *models.ChannelEdgePolicy
×
1948
        )
×
1949
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1950
                row, err := db.GetChannelByOutpointWithPolicies(
×
1951
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1952
                                Outpoint: op.String(),
×
1953
                                Version:  int16(ProtocolV1),
×
1954
                        },
×
1955
                )
×
1956
                if errors.Is(err, sql.ErrNoRows) {
×
1957
                        return ErrEdgeNotFound
×
1958
                } else if err != nil {
×
1959
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1960
                }
×
1961

1962
                node1, node2, err := buildNodeVertices(
×
1963
                        row.Node1Pubkey, row.Node2Pubkey,
×
1964
                )
×
1965
                if err != nil {
×
1966
                        return err
×
1967
                }
×
1968

1969
                edge, err = getAndBuildEdgeInfo(
×
1970
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1971
                        row.GraphChannel, node1, node2,
×
1972
                )
×
1973
                if err != nil {
×
1974
                        return fmt.Errorf("unable to build channel info: %w",
×
1975
                                err)
×
1976
                }
×
1977

1978
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1979
                if err != nil {
×
1980
                        return fmt.Errorf("unable to extract channel "+
×
1981
                                "policies: %w", err)
×
1982
                }
×
1983

1984
                policy1, policy2, err = getAndBuildChanPolicies(
×
1985
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1986
                )
×
1987
                if err != nil {
×
1988
                        return fmt.Errorf("unable to build channel "+
×
1989
                                "policies: %w", err)
×
1990
                }
×
1991

1992
                return nil
×
1993
        }, sqldb.NoOpReset)
1994
        if err != nil {
×
1995
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1996
                        err)
×
1997
        }
×
1998

1999
        return edge, policy1, policy2, nil
×
2000
}
2001

2002
// HasChannelEdge returns true if the database knows of a channel edge with the
2003
// passed channel ID, and false otherwise. If an edge with that ID is found
2004
// within the graph, then two time stamps representing the last time the edge
2005
// was updated for both directed edges are returned along with the boolean. If
2006
// it is not found, then the zombie index is checked and its result is returned
2007
// as the second boolean.
2008
//
2009
// NOTE: part of the V1Store interface.
2010
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2011
        bool, error) {
×
2012

×
2013
        ctx := context.TODO()
×
2014

×
2015
        var (
×
2016
                exists          bool
×
2017
                isZombie        bool
×
2018
                node1LastUpdate time.Time
×
2019
                node2LastUpdate time.Time
×
2020
        )
×
2021

×
2022
        // We'll query the cache with the shared lock held to allow multiple
×
2023
        // readers to access values in the cache concurrently if they exist.
×
2024
        s.cacheMu.RLock()
×
2025
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2026
                s.cacheMu.RUnlock()
×
2027
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2028
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2029
                exists, isZombie = entry.flags.unpack()
×
2030

×
2031
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2032
        }
×
2033
        s.cacheMu.RUnlock()
×
2034

×
2035
        s.cacheMu.Lock()
×
2036
        defer s.cacheMu.Unlock()
×
2037

×
2038
        // The item was not found with the shared lock, so we'll acquire the
×
2039
        // exclusive lock and check the cache again in case another method added
×
2040
        // the entry to the cache while no lock was held.
×
2041
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2042
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2043
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2044
                exists, isZombie = entry.flags.unpack()
×
2045

×
2046
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2047
        }
×
2048

2049
        chanIDB := channelIDToBytes(chanID)
×
2050
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2051
                channel, err := db.GetChannelBySCID(
×
2052
                        ctx, sqlc.GetChannelBySCIDParams{
×
2053
                                Scid:    chanIDB,
×
2054
                                Version: int16(ProtocolV1),
×
2055
                        },
×
2056
                )
×
2057
                if errors.Is(err, sql.ErrNoRows) {
×
2058
                        // Check if it is a zombie channel.
×
2059
                        isZombie, err = db.IsZombieChannel(
×
2060
                                ctx, sqlc.IsZombieChannelParams{
×
2061
                                        Scid:    chanIDB,
×
2062
                                        Version: int16(ProtocolV1),
×
2063
                                },
×
2064
                        )
×
2065
                        if err != nil {
×
2066
                                return fmt.Errorf("could not check if channel "+
×
2067
                                        "is zombie: %w", err)
×
2068
                        }
×
2069

2070
                        return nil
×
2071
                } else if err != nil {
×
2072
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2073
                }
×
2074

2075
                exists = true
×
2076

×
2077
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2078
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2079
                                Version:   int16(ProtocolV1),
×
2080
                                ChannelID: channel.ID,
×
2081
                                NodeID:    channel.NodeID1,
×
2082
                        },
×
2083
                )
×
2084
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2085
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2086
                                err)
×
2087
                } else if err == nil {
×
2088
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2089
                }
×
2090

2091
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2092
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2093
                                Version:   int16(ProtocolV1),
×
2094
                                ChannelID: channel.ID,
×
2095
                                NodeID:    channel.NodeID2,
×
2096
                        },
×
2097
                )
×
2098
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2099
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2100
                                err)
×
2101
                } else if err == nil {
×
2102
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2103
                }
×
2104

2105
                return nil
×
2106
        }, sqldb.NoOpReset)
2107
        if err != nil {
×
2108
                return time.Time{}, time.Time{}, false, false,
×
2109
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2110
        }
×
2111

2112
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2113
                upd1Time: node1LastUpdate.Unix(),
×
2114
                upd2Time: node2LastUpdate.Unix(),
×
2115
                flags:    packRejectFlags(exists, isZombie),
×
2116
        })
×
2117

×
2118
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2119
}
2120

2121
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2122
// passed channel point (outpoint). If the passed channel doesn't exist within
2123
// the database, then ErrEdgeNotFound is returned.
2124
//
2125
// NOTE: part of the V1Store interface.
2126
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2127
        var (
×
2128
                ctx       = context.TODO()
×
2129
                channelID uint64
×
2130
        )
×
2131
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2132
                chanID, err := db.GetSCIDByOutpoint(
×
2133
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2134
                                Outpoint: chanPoint.String(),
×
2135
                                Version:  int16(ProtocolV1),
×
2136
                        },
×
2137
                )
×
2138
                if errors.Is(err, sql.ErrNoRows) {
×
2139
                        return ErrEdgeNotFound
×
2140
                } else if err != nil {
×
2141
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2142
                                err)
×
2143
                }
×
2144

2145
                channelID = byteOrder.Uint64(chanID)
×
2146

×
2147
                return nil
×
2148
        }, sqldb.NoOpReset)
2149
        if err != nil {
×
2150
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2151
        }
×
2152

2153
        return channelID, nil
×
2154
}
2155

2156
// IsPublicNode is a helper method that determines whether the node with the
2157
// given public key is seen as a public node in the graph from the graph's
2158
// source node's point of view.
2159
//
2160
// NOTE: part of the V1Store interface.
2161
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2162
        ctx := context.TODO()
×
2163

×
2164
        var isPublic bool
×
2165
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2166
                var err error
×
2167
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2168

×
2169
                return err
×
2170
        }, sqldb.NoOpReset)
×
2171
        if err != nil {
×
2172
                return false, fmt.Errorf("unable to check if node is "+
×
2173
                        "public: %w", err)
×
2174
        }
×
2175

2176
        return isPublic, nil
×
2177
}
2178

2179
// FetchChanInfos returns the set of channel edges that correspond to the passed
2180
// channel ID's. If an edge is the query is unknown to the database, it will
2181
// skipped and the result will contain only those edges that exist at the time
2182
// of the query. This can be used to respond to peer queries that are seeking to
2183
// fill in gaps in their view of the channel graph.
2184
//
2185
// NOTE: part of the V1Store interface.
2186
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2187
        var (
×
2188
                ctx   = context.TODO()
×
NEW
2189
                edges = make(map[uint64]ChannelEdge)
×
2190
        )
×
2191
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2192
                chanCallBack := func(ctx context.Context,
×
NEW
2193
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2194

×
2195
                        node1, node2, err := buildNodes(
×
2196
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2197
                        )
×
2198
                        if err != nil {
×
2199
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2200
                                        err)
×
2201
                        }
×
2202

2203
                        edge, err := getAndBuildEdgeInfo(
×
2204
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2205
                                row.GraphChannel, node1.PubKeyBytes,
×
2206
                                node2.PubKeyBytes,
×
2207
                        )
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to build "+
×
2210
                                        "channel info: %w", err)
×
2211
                        }
×
2212

2213
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2214
                        if err != nil {
×
2215
                                return fmt.Errorf("unable to extract channel "+
×
2216
                                        "policies: %w", err)
×
2217
                        }
×
2218

2219
                        p1, p2, err := getAndBuildChanPolicies(
×
2220
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2221
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2222
                        )
×
2223
                        if err != nil {
×
2224
                                return fmt.Errorf("unable to build channel "+
×
2225
                                        "policies: %w", err)
×
2226
                        }
×
2227

NEW
2228
                        edges[edge.ChannelID] = ChannelEdge{
×
2229
                                Info:    edge,
×
2230
                                Policy1: p1,
×
2231
                                Policy2: p2,
×
2232
                                Node1:   node1,
×
2233
                                Node2:   node2,
×
NEW
2234
                        }
×
NEW
2235

×
NEW
2236
                        return nil
×
2237
                }
2238

NEW
2239
                return s.forEachChanWithPoliciesInSCIDList(
×
NEW
2240
                        ctx, db, chanCallBack, chanIDs,
×
NEW
2241
                )
×
2242
        }, func() {
×
NEW
2243
                clear(edges)
×
2244
        })
×
2245
        if err != nil {
×
2246
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2247
        }
×
2248

NEW
2249
        res := make([]ChannelEdge, 0, len(edges))
×
NEW
2250
        for _, chanID := range chanIDs {
×
NEW
2251
                edge, ok := edges[chanID]
×
NEW
2252
                if !ok {
×
NEW
2253
                        continue
×
2254
                }
2255

NEW
2256
                res = append(res, edge)
×
2257
        }
2258

NEW
2259
        return res, nil
×
2260
}
2261

2262
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2263
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2264
// channels in a paginated manner.
2265
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2266
        db SQLQueries, cb func(ctx context.Context,
2267
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
NEW
2268
        chanIDs []uint64) error {
×
NEW
2269

×
NEW
2270
        queryWrapper := func(ctx context.Context,
×
NEW
2271
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
NEW
2272
                error) {
×
NEW
2273

×
NEW
2274
                return db.GetChannelsBySCIDWithPolicies(
×
NEW
2275
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2276
                                Version: int16(ProtocolV1),
×
NEW
2277
                                Scids:   scids,
×
NEW
2278
                        },
×
NEW
2279
                )
×
NEW
2280
        }
×
2281

NEW
2282
        return sqldb.ExecutePagedQuery(
×
NEW
2283
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
NEW
2284
                queryWrapper, cb,
×
NEW
2285
        )
×
2286
}
2287

2288
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2289
// ID's that we don't know and are not known zombies of the passed set. In other
2290
// words, we perform a set difference of our set of chan ID's and the ones
2291
// passed in. This method can be used by callers to determine the set of
2292
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2293
// known zombies is also returned.
2294
//
2295
// NOTE: part of the V1Store interface.
2296
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2297
        []ChannelUpdateInfo, error) {
×
2298

×
2299
        var (
×
2300
                ctx          = context.TODO()
×
2301
                newChanIDs   []uint64
×
2302
                knownZombies []ChannelUpdateInfo
×
NEW
2303
                infoLookup   = make(
×
NEW
2304
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
NEW
2305
                )
×
2306
        )
×
NEW
2307

×
NEW
2308
        // We first build a lookup map of the channel ID's to the
×
NEW
2309
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
NEW
2310
        // already know about.
×
NEW
2311
        for _, chanInfo := range chansInfo {
×
NEW
2312
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
NEW
2313
        }
×
2314

2315
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2316
                // The call-back function deletes known channels from
×
NEW
2317
                // infoLookup, so that we can later check which channels are
×
NEW
2318
                // zombies by only looking at the remaining channels in the set.
×
NEW
2319
                cb := func(ctx context.Context,
×
NEW
2320
                        channel sqlc.GraphChannel) error {
×
NEW
2321

×
NEW
2322
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
NEW
2323

×
NEW
2324
                        return nil
×
NEW
2325
                }
×
2326

NEW
2327
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
NEW
2328
                if err != nil {
×
NEW
2329
                        return fmt.Errorf("unable to iterate through "+
×
NEW
2330
                                "channels: %w", err)
×
NEW
2331
                }
×
2332

2333
                // We want to ensure that we deal with the channels in the
2334
                // same order that they were passed in, so we iterate over the
2335
                // original chansInfo slice and then check if that channel is
2336
                // still in the infoLookup map.
2337
                for _, chanInfo := range chansInfo {
×
2338
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2339
                        if _, ok := infoLookup[channelID]; !ok {
×
2340
                                continue
×
2341
                        }
2342

2343
                        isZombie, err := db.IsZombieChannel(
×
2344
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2345
                                        Scid:    channelIDToBytes(channelID),
×
2346
                                        Version: int16(ProtocolV1),
×
2347
                                },
×
2348
                        )
×
2349
                        if err != nil {
×
2350
                                return fmt.Errorf("unable to fetch zombie "+
×
2351
                                        "channel: %w", err)
×
2352
                        }
×
2353

2354
                        if isZombie {
×
2355
                                knownZombies = append(knownZombies, chanInfo)
×
2356

×
2357
                                continue
×
2358
                        }
2359

2360
                        newChanIDs = append(newChanIDs, channelID)
×
2361
                }
2362

2363
                return nil
×
2364
        }, func() {
×
2365
                newChanIDs = nil
×
2366
                knownZombies = nil
×
NEW
2367
                // Rebuild the infoLookup map in case of a rollback.
×
NEW
2368
                for _, chanInfo := range chansInfo {
×
NEW
2369
                        scid := chanInfo.ShortChannelID.ToUint64()
×
NEW
2370
                        infoLookup[scid] = chanInfo
×
NEW
2371
                }
×
2372
        })
2373
        if err != nil {
×
2374
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2375
        }
×
2376

2377
        return newChanIDs, knownZombies, nil
×
2378
}
2379

2380
// forEachChanInSCIDList is a helper method that executes a paged query
2381
// against the database to fetch all channels that match the passed
2382
// ChannelUpdateInfo slice. The callback function is called for each channel
2383
// that is found.
2384
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2385
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
NEW
2386
        chansInfo []ChannelUpdateInfo) error {
×
NEW
2387

×
NEW
2388
        queryWrapper := func(ctx context.Context,
×
NEW
2389
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
NEW
2390

×
NEW
2391
                return db.GetChannelsBySCIDs(
×
NEW
2392
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2393
                                Version: int16(ProtocolV1),
×
NEW
2394
                                Scids:   scids,
×
NEW
2395
                        },
×
NEW
2396
                )
×
NEW
2397
        }
×
2398

NEW
2399
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
NEW
2400
                channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2401

×
NEW
2402
                return channelIDToBytes(channelID)
×
NEW
2403
        }
×
2404

NEW
2405
        return sqldb.ExecutePagedQuery(
×
NEW
2406
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
NEW
2407
                queryWrapper, cb,
×
NEW
2408
        )
×
2409
}
2410

2411
// PruneGraphNodes is a garbage collection method which attempts to prune out
2412
// any nodes from the channel graph that are currently unconnected. This ensure
2413
// that we only maintain a graph of reachable nodes. In the event that a pruned
2414
// node gains more channels, it will be re-added back to the graph.
2415
//
2416
// NOTE: this prunes nodes across protocol versions. It will never prune the
2417
// source nodes.
2418
//
2419
// NOTE: part of the V1Store interface.
2420
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2421
        var ctx = context.TODO()
×
2422

×
2423
        var prunedNodes []route.Vertex
×
2424
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2425
                var err error
×
2426
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2427

×
2428
                return err
×
2429
        }, func() {
×
2430
                prunedNodes = nil
×
2431
        })
×
2432
        if err != nil {
×
2433
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2434
        }
×
2435

2436
        return prunedNodes, nil
×
2437
}
2438

2439
// PruneGraph prunes newly closed channels from the channel graph in response
2440
// to a new block being solved on the network. Any transactions which spend the
2441
// funding output of any known channels within he graph will be deleted.
2442
// Additionally, the "prune tip", or the last block which has been used to
2443
// prune the graph is stored so callers can ensure the graph is fully in sync
2444
// with the current UTXO state. A slice of channels that have been closed by
2445
// the target block along with any pruned nodes are returned if the function
2446
// succeeds without error.
2447
//
2448
// NOTE: part of the V1Store interface.
2449
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2450
        blockHash *chainhash.Hash, blockHeight uint32) (
2451
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2452

×
2453
        ctx := context.TODO()
×
2454

×
2455
        s.cacheMu.Lock()
×
2456
        defer s.cacheMu.Unlock()
×
2457

×
2458
        var (
×
2459
                closedChans []*models.ChannelEdgeInfo
×
2460
                prunedNodes []route.Vertex
×
2461
        )
×
2462
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2463
                var chansToDelete []int64
×
NEW
2464

×
NEW
2465
                // Define the callback function for processing each channel.
×
NEW
2466
                channelCallback := func(ctx context.Context,
×
NEW
2467
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2468

×
2469
                        node1, node2, err := buildNodeVertices(
×
2470
                                row.Node1Pubkey, row.Node2Pubkey,
×
2471
                        )
×
2472
                        if err != nil {
×
2473
                                return err
×
2474
                        }
×
2475

2476
                        info, err := getAndBuildEdgeInfo(
×
2477
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2478
                                row.GraphChannel, node1, node2,
×
2479
                        )
×
2480
                        if err != nil {
×
2481
                                return err
×
2482
                        }
×
2483

2484
                        closedChans = append(closedChans, info)
×
NEW
2485
                        chansToDelete = append(
×
NEW
2486
                                chansToDelete, row.GraphChannel.ID,
×
NEW
2487
                        )
×
NEW
2488

×
NEW
2489
                        return nil
×
2490
                }
2491

NEW
2492
                err := s.forEachChanInOutpoints(
×
NEW
2493
                        ctx, db, spentOutputs, channelCallback,
×
NEW
2494
                )
×
NEW
2495
                if err != nil {
×
NEW
2496
                        return fmt.Errorf("unable to fetch channels by "+
×
NEW
2497
                                "outpoints: %w", err)
×
UNCOV
2498
                }
×
2499

NEW
2500
                err = s.deleteChannels(ctx, db, chansToDelete)
×
NEW
2501
                if err != nil {
×
NEW
2502
                        return fmt.Errorf("unable to delete channels: %w", err)
×
NEW
2503
                }
×
2504

NEW
2505
                err = db.UpsertPruneLogEntry(
×
2506
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2507
                                BlockHash:   blockHash[:],
×
2508
                                BlockHeight: int64(blockHeight),
×
2509
                        },
×
2510
                )
×
2511
                if err != nil {
×
2512
                        return fmt.Errorf("unable to insert prune log "+
×
2513
                                "entry: %w", err)
×
2514
                }
×
2515

2516
                // Now that we've pruned some channels, we'll also prune any
2517
                // nodes that no longer have any channels.
2518
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2519
                if err != nil {
×
2520
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2521
                                err)
×
2522
                }
×
2523

2524
                return nil
×
2525
        }, func() {
×
2526
                prunedNodes = nil
×
2527
                closedChans = nil
×
2528
        })
×
2529
        if err != nil {
×
2530
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2531
        }
×
2532

2533
        for _, channel := range closedChans {
×
2534
                s.rejectCache.remove(channel.ChannelID)
×
2535
                s.chanCache.remove(channel.ChannelID)
×
2536
        }
×
2537

2538
        return closedChans, prunedNodes, nil
×
2539
}
2540

2541
// forEachChanInOutpoints is a helper function that executes a paginated
2542
// query to fetch channels by their outpoints and applies the given call-back
2543
// to each.
2544
//
2545
// NOTE: this fetches channels for all protocol versions.
2546
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2547
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
NEW
2548
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
NEW
2549

×
NEW
2550
        // Create a wrapper that uses the transaction's db instance to execute
×
NEW
2551
        // the query.
×
NEW
2552
        queryWrapper := func(ctx context.Context,
×
NEW
2553
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
NEW
2554
                error) {
×
NEW
2555

×
NEW
2556
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
NEW
2557
        }
×
2558

2559
        // Define the conversion function from Outpoint to string.
NEW
2560
        outpointToString := func(outpoint *wire.OutPoint) string {
×
NEW
2561
                return outpoint.String()
×
NEW
2562
        }
×
2563

NEW
2564
        return sqldb.ExecutePagedQuery(
×
NEW
2565
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
NEW
2566
                queryWrapper, cb,
×
NEW
2567
        )
×
2568
}
2569

2570
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
NEW
2571
        dbIDs []int64) error {
×
NEW
2572

×
NEW
2573
        // Create a wrapper that uses the transaction's db instance to execute
×
NEW
2574
        // the query.
×
NEW
2575
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
NEW
2576
                return nil, db.DeleteChannels(ctx, ids)
×
NEW
2577
        }
×
2578

NEW
2579
        idConverter := func(id int64) int64 {
×
NEW
2580
                return id
×
NEW
2581
        }
×
2582

NEW
2583
        return sqldb.ExecutePagedQuery(
×
NEW
2584
                ctx, s.cfg.PaginationCfg, dbIDs, idConverter,
×
NEW
2585
                queryWrapper, func(ctx context.Context, _ any) error {
×
NEW
2586
                        return nil
×
NEW
2587
                },
×
2588
        )
2589
}
2590

2591
// ChannelView returns the verifiable edge information for each active channel
2592
// within the known channel graph. The set of UTXOs (along with their scripts)
2593
// returned are the ones that need to be watched on chain to detect channel
2594
// closes on the resident blockchain.
2595
//
2596
// NOTE: part of the V1Store interface.
2597
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2598
        var (
×
2599
                ctx        = context.TODO()
×
2600
                edgePoints []EdgePoint
×
2601
        )
×
2602

×
2603
        handleChannel := func(db SQLQueries,
×
2604
                channel sqlc.ListChannelsPaginatedRow) error {
×
2605

×
2606
                pkScript, err := genMultiSigP2WSH(
×
2607
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2608
                )
×
2609
                if err != nil {
×
2610
                        return err
×
2611
                }
×
2612

2613
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2614
                if err != nil {
×
2615
                        return err
×
2616
                }
×
2617

2618
                edgePoints = append(edgePoints, EdgePoint{
×
2619
                        FundingPkScript: pkScript,
×
2620
                        OutPoint:        *op,
×
2621
                })
×
2622

×
2623
                return nil
×
2624
        }
2625

2626
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2627
                lastID := int64(-1)
×
2628
                for {
×
2629
                        rows, err := db.ListChannelsPaginated(
×
2630
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2631
                                        Version: int16(ProtocolV1),
×
2632
                                        ID:      lastID,
×
2633
                                        Limit:   pageSize,
×
2634
                                },
×
2635
                        )
×
2636
                        if err != nil {
×
2637
                                return err
×
2638
                        }
×
2639

2640
                        if len(rows) == 0 {
×
2641
                                break
×
2642
                        }
2643

2644
                        for _, row := range rows {
×
2645
                                err := handleChannel(db, row)
×
2646
                                if err != nil {
×
2647
                                        return err
×
2648
                                }
×
2649

2650
                                lastID = row.ID
×
2651
                        }
2652
                }
2653

2654
                return nil
×
2655
        }, func() {
×
2656
                edgePoints = nil
×
2657
        })
×
2658
        if err != nil {
×
2659
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2660
        }
×
2661

2662
        return edgePoints, nil
×
2663
}
2664

2665
// PruneTip returns the block height and hash of the latest block that has been
2666
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2667
// to tell if the graph is currently in sync with the current best known UTXO
2668
// state.
2669
//
2670
// NOTE: part of the V1Store interface.
2671
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2672
        var (
×
2673
                ctx       = context.TODO()
×
2674
                tipHash   chainhash.Hash
×
2675
                tipHeight uint32
×
2676
        )
×
2677
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2678
                pruneTip, err := db.GetPruneTip(ctx)
×
2679
                if errors.Is(err, sql.ErrNoRows) {
×
2680
                        return ErrGraphNeverPruned
×
2681
                } else if err != nil {
×
2682
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2683
                }
×
2684

2685
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2686
                tipHeight = uint32(pruneTip.BlockHeight)
×
2687

×
2688
                return nil
×
2689
        }, sqldb.NoOpReset)
2690
        if err != nil {
×
2691
                return nil, 0, err
×
2692
        }
×
2693

2694
        return &tipHash, tipHeight, nil
×
2695
}
2696

2697
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2698
//
2699
// NOTE: this prunes nodes across protocol versions. It will never prune the
2700
// source nodes.
2701
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2702
        db SQLQueries) ([]route.Vertex, error) {
×
2703

×
2704
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2705
        if err != nil {
×
2706
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2707
                        "nodes: %w", err)
×
2708
        }
×
2709

2710
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2711
        for i, nodeKey := range nodeKeys {
×
2712
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2713
                if err != nil {
×
2714
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2715
                                "from bytes: %w", err)
×
2716
                }
×
2717

2718
                prunedNodes[i] = pub
×
2719
        }
2720

2721
        return prunedNodes, nil
×
2722
}
2723

2724
// DisconnectBlockAtHeight is used to indicate that the block specified
2725
// by the passed height has been disconnected from the main chain. This
2726
// will "rewind" the graph back to the height below, deleting channels
2727
// that are no longer confirmed from the graph. The prune log will be
2728
// set to the last prune height valid for the remaining chain.
2729
// Channels that were removed from the graph resulting from the
2730
// disconnected block are returned.
2731
//
2732
// NOTE: part of the V1Store interface.
2733
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2734
        []*models.ChannelEdgeInfo, error) {
×
2735

×
2736
        ctx := context.TODO()
×
2737

×
2738
        var (
×
2739
                // Every channel having a ShortChannelID starting at 'height'
×
2740
                // will no longer be confirmed.
×
2741
                startShortChanID = lnwire.ShortChannelID{
×
2742
                        BlockHeight: height,
×
2743
                }
×
2744

×
2745
                // Delete everything after this height from the db up until the
×
2746
                // SCID alias range.
×
2747
                endShortChanID = aliasmgr.StartingAlias
×
2748

×
2749
                removedChans []*models.ChannelEdgeInfo
×
2750

×
2751
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2752
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2753
        )
×
2754

×
2755
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2756
                rows, err := db.GetChannelsBySCIDRange(
×
2757
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2758
                                StartScid: chanIDStart,
×
2759
                                EndScid:   chanIDEnd,
×
2760
                        },
×
2761
                )
×
2762
                if err != nil {
×
2763
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2764
                }
×
2765

NEW
2766
                chanIDsToDelete := make([]int64, len(rows))
×
NEW
2767
                for i, row := range rows {
×
2768
                        node1, node2, err := buildNodeVertices(
×
2769
                                row.Node1PubKey, row.Node2PubKey,
×
2770
                        )
×
2771
                        if err != nil {
×
2772
                                return err
×
2773
                        }
×
2774

2775
                        channel, err := getAndBuildEdgeInfo(
×
2776
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2777
                                row.GraphChannel, node1, node2,
×
2778
                        )
×
2779
                        if err != nil {
×
2780
                                return err
×
2781
                        }
×
2782

NEW
2783
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2784
                        removedChans = append(removedChans, channel)
×
2785
                }
2786

NEW
2787
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
NEW
2788
                if err != nil {
×
NEW
2789
                        return fmt.Errorf("unable to delete channels: %w", err)
×
NEW
2790
                }
×
2791

2792
                return db.DeletePruneLogEntriesInRange(
×
2793
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2794
                                StartHeight: int64(height),
×
2795
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2796
                        },
×
2797
                )
×
2798
        }, func() {
×
2799
                removedChans = nil
×
2800
        })
×
2801
        if err != nil {
×
2802
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2803
                        "height: %w", err)
×
2804
        }
×
2805

2806
        for _, channel := range removedChans {
×
2807
                s.rejectCache.remove(channel.ChannelID)
×
2808
                s.chanCache.remove(channel.ChannelID)
×
2809
        }
×
2810

2811
        return removedChans, nil
×
2812
}
2813

2814
// AddEdgeProof sets the proof of an existing edge in the graph database.
2815
//
2816
// NOTE: part of the V1Store interface.
2817
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2818
        proof *models.ChannelAuthProof) error {
×
2819

×
2820
        var (
×
2821
                ctx       = context.TODO()
×
2822
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2823
        )
×
2824

×
2825
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2826
                res, err := db.AddV1ChannelProof(
×
2827
                        ctx, sqlc.AddV1ChannelProofParams{
×
2828
                                Scid:              scidBytes,
×
2829
                                Node1Signature:    proof.NodeSig1Bytes,
×
2830
                                Node2Signature:    proof.NodeSig2Bytes,
×
2831
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2832
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2833
                        },
×
2834
                )
×
2835
                if err != nil {
×
2836
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2837
                }
×
2838

2839
                n, err := res.RowsAffected()
×
2840
                if err != nil {
×
2841
                        return err
×
2842
                }
×
2843

2844
                if n == 0 {
×
2845
                        return fmt.Errorf("no rows affected when adding edge "+
×
2846
                                "proof for SCID %v", scid)
×
2847
                } else if n > 1 {
×
2848
                        return fmt.Errorf("multiple rows affected when adding "+
×
2849
                                "edge proof for SCID %v: %d rows affected",
×
2850
                                scid, n)
×
2851
                }
×
2852

2853
                return nil
×
2854
        }, sqldb.NoOpReset)
2855
        if err != nil {
×
2856
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2857
        }
×
2858

2859
        return nil
×
2860
}
2861

2862
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2863
// that we can ignore channel announcements that we know to be closed without
2864
// having to validate them and fetch a block.
2865
//
2866
// NOTE: part of the V1Store interface.
2867
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2868
        var (
×
2869
                ctx     = context.TODO()
×
2870
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2871
        )
×
2872

×
2873
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2874
                return db.InsertClosedChannel(ctx, chanIDB)
×
2875
        }, sqldb.NoOpReset)
×
2876
}
2877

2878
// IsClosedScid checks whether a channel identified by the passed in scid is
2879
// closed. This helps avoid having to perform expensive validation checks.
2880
//
2881
// NOTE: part of the V1Store interface.
2882
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2883
        var (
×
2884
                ctx      = context.TODO()
×
2885
                isClosed bool
×
2886
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2887
        )
×
2888
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2889
                var err error
×
2890
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2891
                if err != nil {
×
2892
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2893
                                err)
×
2894
                }
×
2895

2896
                return nil
×
2897
        }, sqldb.NoOpReset)
2898
        if err != nil {
×
2899
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2900
                        err)
×
2901
        }
×
2902

2903
        return isClosed, nil
×
2904
}
2905

2906
// GraphSession will provide the call-back with access to a NodeTraverser
2907
// instance which can be used to perform queries against the channel graph.
2908
//
2909
// NOTE: part of the V1Store interface.
2910
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2911
        reset func()) error {
×
2912

×
2913
        var ctx = context.TODO()
×
2914

×
2915
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2916
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2917
        }, reset)
×
2918
}
2919

2920
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2921
// read only transaction for a consistent view of the graph.
2922
type sqlNodeTraverser struct {
2923
        db    SQLQueries
2924
        chain chainhash.Hash
2925
}
2926

2927
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2928
// NodeTraverser interface.
2929
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2930

2931
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2932
func newSQLNodeTraverser(db SQLQueries,
2933
        chain chainhash.Hash) *sqlNodeTraverser {
×
2934

×
2935
        return &sqlNodeTraverser{
×
2936
                db:    db,
×
2937
                chain: chain,
×
2938
        }
×
2939
}
×
2940

2941
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2942
// node.
2943
//
2944
// NOTE: Part of the NodeTraverser interface.
2945
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2946
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2947

×
2948
        ctx := context.TODO()
×
2949

×
2950
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2951
}
×
2952

2953
// FetchNodeFeatures returns the features of the given node. If the node is
2954
// unknown, assume no additional features are supported.
2955
//
2956
// NOTE: Part of the NodeTraverser interface.
2957
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2958
        *lnwire.FeatureVector, error) {
×
2959

×
2960
        ctx := context.TODO()
×
2961

×
2962
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2963
}
×
2964

2965
// forEachNodeDirectedChannel iterates through all channels of a given
2966
// node, executing the passed callback on the directed edge representing the
2967
// channel and its incoming policy. If the node is not found, no error is
2968
// returned.
2969
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2970
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2971

×
2972
        toNodeCallback := func() route.Vertex {
×
2973
                return nodePub
×
2974
        }
×
2975

2976
        dbID, err := db.GetNodeIDByPubKey(
×
2977
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2978
                        Version: int16(ProtocolV1),
×
2979
                        PubKey:  nodePub[:],
×
2980
                },
×
2981
        )
×
2982
        if errors.Is(err, sql.ErrNoRows) {
×
2983
                return nil
×
2984
        } else if err != nil {
×
2985
                return fmt.Errorf("unable to fetch node: %w", err)
×
2986
        }
×
2987

2988
        rows, err := db.ListChannelsByNodeID(
×
2989
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2990
                        Version: int16(ProtocolV1),
×
2991
                        NodeID1: dbID,
×
2992
                },
×
2993
        )
×
2994
        if err != nil {
×
2995
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2996
        }
×
2997

2998
        // Exit early if there are no channels for this node so we don't
2999
        // do the unnecessary feature fetching.
3000
        if len(rows) == 0 {
×
3001
                return nil
×
3002
        }
×
3003

3004
        features, err := getNodeFeatures(ctx, db, dbID)
×
3005
        if err != nil {
×
3006
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3007
        }
×
3008

3009
        for _, row := range rows {
×
3010
                node1, node2, err := buildNodeVertices(
×
3011
                        row.Node1Pubkey, row.Node2Pubkey,
×
3012
                )
×
3013
                if err != nil {
×
3014
                        return fmt.Errorf("unable to build node vertices: %w",
×
3015
                                err)
×
3016
                }
×
3017

3018
                edge := buildCacheableChannelInfo(
×
3019
                        row.GraphChannel, node1, node2,
×
3020
                )
×
3021

×
3022
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3023
                if err != nil {
×
3024
                        return err
×
3025
                }
×
3026

3027
                var p1, p2 *models.CachedEdgePolicy
×
3028
                if dbPol1 != nil {
×
3029
                        policy1, err := buildChanPolicy(
×
3030
                                *dbPol1, edge.ChannelID, nil, node2,
×
3031
                        )
×
3032
                        if err != nil {
×
3033
                                return err
×
3034
                        }
×
3035

3036
                        p1 = models.NewCachedPolicy(policy1)
×
3037
                }
3038
                if dbPol2 != nil {
×
3039
                        policy2, err := buildChanPolicy(
×
3040
                                *dbPol2, edge.ChannelID, nil, node1,
×
3041
                        )
×
3042
                        if err != nil {
×
3043
                                return err
×
3044
                        }
×
3045

3046
                        p2 = models.NewCachedPolicy(policy2)
×
3047
                }
3048

3049
                // Determine the outgoing and incoming policy for this
3050
                // channel and node combo.
3051
                outPolicy, inPolicy := p1, p2
×
3052
                if p1 != nil && node2 == nodePub {
×
3053
                        outPolicy, inPolicy = p2, p1
×
3054
                } else if p2 != nil && node1 != nodePub {
×
3055
                        outPolicy, inPolicy = p2, p1
×
3056
                }
×
3057

3058
                var cachedInPolicy *models.CachedEdgePolicy
×
3059
                if inPolicy != nil {
×
3060
                        cachedInPolicy = inPolicy
×
3061
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3062
                        cachedInPolicy.ToNodeFeatures = features
×
3063
                }
×
3064

3065
                directedChannel := &DirectedChannel{
×
3066
                        ChannelID:    edge.ChannelID,
×
3067
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3068
                        OtherNode:    edge.NodeKey2Bytes,
×
3069
                        Capacity:     edge.Capacity,
×
3070
                        OutPolicySet: outPolicy != nil,
×
3071
                        InPolicy:     cachedInPolicy,
×
3072
                }
×
3073
                if outPolicy != nil {
×
3074
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3075
                                directedChannel.InboundFee = fee
×
3076
                        })
×
3077
                }
3078

3079
                if nodePub == edge.NodeKey2Bytes {
×
3080
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3081
                }
×
3082

3083
                if err := cb(directedChannel); err != nil {
×
3084
                        return err
×
3085
                }
×
3086
        }
3087

3088
        return nil
×
3089
}
3090

3091
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3092
// and executes the provided callback for each node.
3093
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3094
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3095

×
3096
        lastID := int64(-1)
×
3097

×
3098
        for {
×
3099
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3100
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3101
                                Version: int16(ProtocolV1),
×
3102
                                ID:      lastID,
×
3103
                                Limit:   pageSize,
×
3104
                        },
×
3105
                )
×
3106
                if err != nil {
×
3107
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3108
                }
×
3109

3110
                if len(nodes) == 0 {
×
3111
                        break
×
3112
                }
3113

3114
                for _, node := range nodes {
×
3115
                        var pub route.Vertex
×
3116
                        copy(pub[:], node.PubKey)
×
3117

×
3118
                        if err := cb(node.ID, pub); err != nil {
×
3119
                                return fmt.Errorf("forEachNodeCacheable "+
×
3120
                                        "callback failed for node(id=%d): %w",
×
3121
                                        node.ID, err)
×
3122
                        }
×
3123

3124
                        lastID = node.ID
×
3125
                }
3126
        }
3127

3128
        return nil
×
3129
}
3130

3131
// forEachNodeChannel iterates through all channels of a node, executing
3132
// the passed callback on each. The call-back is provided with the channel's
3133
// edge information, the outgoing policy and the incoming policy for the
3134
// channel and node combo.
3135
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3136
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3137
                *models.ChannelEdgePolicy,
3138
                *models.ChannelEdgePolicy) error) error {
×
3139

×
3140
        // Get all the V1 channels for this node.Add commentMore actions
×
3141
        rows, err := db.ListChannelsByNodeID(
×
3142
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3143
                        Version: int16(ProtocolV1),
×
3144
                        NodeID1: id,
×
3145
                },
×
3146
        )
×
3147
        if err != nil {
×
3148
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3149
        }
×
3150

3151
        // Call the call-back for each channel and its known policies.
3152
        for _, row := range rows {
×
3153
                node1, node2, err := buildNodeVertices(
×
3154
                        row.Node1Pubkey, row.Node2Pubkey,
×
3155
                )
×
3156
                if err != nil {
×
3157
                        return fmt.Errorf("unable to build node vertices: %w",
×
3158
                                err)
×
3159
                }
×
3160

3161
                edge, err := getAndBuildEdgeInfo(
×
3162
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
3163
                        node1, node2,
×
3164
                )
×
3165
                if err != nil {
×
3166
                        return fmt.Errorf("unable to build channel info: %w",
×
3167
                                err)
×
3168
                }
×
3169

3170
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3171
                if err != nil {
×
3172
                        return fmt.Errorf("unable to extract channel "+
×
3173
                                "policies: %w", err)
×
3174
                }
×
3175

3176
                p1, p2, err := getAndBuildChanPolicies(
×
3177
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3178
                )
×
3179
                if err != nil {
×
3180
                        return fmt.Errorf("unable to build channel "+
×
3181
                                "policies: %w", err)
×
3182
                }
×
3183

3184
                // Determine the outgoing and incoming policy for this
3185
                // channel and node combo.
3186
                p1ToNode := row.GraphChannel.NodeID2
×
3187
                p2ToNode := row.GraphChannel.NodeID1
×
3188
                outPolicy, inPolicy := p1, p2
×
3189
                if (p1 != nil && p1ToNode == id) ||
×
3190
                        (p2 != nil && p2ToNode != id) {
×
3191

×
3192
                        outPolicy, inPolicy = p2, p1
×
3193
                }
×
3194

3195
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3196
                        return err
×
3197
                }
×
3198
        }
3199

3200
        return nil
×
3201
}
3202

3203
// updateChanEdgePolicy upserts the channel policy info we have stored for
3204
// a channel we already know of.
3205
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3206
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3207
        error) {
×
3208

×
3209
        var (
×
3210
                node1Pub, node2Pub route.Vertex
×
3211
                isNode1            bool
×
3212
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3213
        )
×
3214

×
3215
        // Check that this edge policy refers to a channel that we already
×
3216
        // know of. We do this explicitly so that we can return the appropriate
×
3217
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3218
        // abort the transaction which would abort the entire batch.
×
3219
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3220
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3221
                        Scid:    chanIDB,
×
3222
                        Version: int16(ProtocolV1),
×
3223
                },
×
3224
        )
×
3225
        if errors.Is(err, sql.ErrNoRows) {
×
3226
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3227
        } else if err != nil {
×
3228
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3229
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3230
        }
×
3231

3232
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3233
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3234

×
3235
        // Figure out which node this edge is from.
×
3236
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3237
        nodeID := dbChan.NodeID1
×
3238
        if !isNode1 {
×
3239
                nodeID = dbChan.NodeID2
×
3240
        }
×
3241

3242
        var (
×
3243
                inboundBase sql.NullInt64
×
3244
                inboundRate sql.NullInt64
×
3245
        )
×
3246
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3247
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3248
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3249
        })
×
3250

3251
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3252
                Version:     int16(ProtocolV1),
×
3253
                ChannelID:   dbChan.ID,
×
3254
                NodeID:      nodeID,
×
3255
                Timelock:    int32(edge.TimeLockDelta),
×
3256
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3257
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3258
                MinHtlcMsat: int64(edge.MinHTLC),
×
3259
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3260
                Disabled: sql.NullBool{
×
3261
                        Valid: true,
×
3262
                        Bool:  edge.IsDisabled(),
×
3263
                },
×
3264
                MaxHtlcMsat: sql.NullInt64{
×
3265
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3266
                        Int64: int64(edge.MaxHTLC),
×
3267
                },
×
3268
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3269
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3270
                InboundBaseFeeMsat:      inboundBase,
×
3271
                InboundFeeRateMilliMsat: inboundRate,
×
3272
                Signature:               edge.SigBytes,
×
3273
        })
×
3274
        if err != nil {
×
3275
                return node1Pub, node2Pub, isNode1,
×
3276
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3277
        }
×
3278

3279
        // Convert the flat extra opaque data into a map of TLV types to
3280
        // values.
3281
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3282
        if err != nil {
×
3283
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3284
                        "marshal extra opaque data: %w", err)
×
3285
        }
×
3286

3287
        // Update the channel policy's extra signed fields.
3288
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3289
        if err != nil {
×
3290
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3291
                        "policy extra TLVs: %w", err)
×
3292
        }
×
3293

3294
        return node1Pub, node2Pub, isNode1, nil
×
3295
}
3296

3297
// getNodeByPubKey attempts to look up a target node by its public key.
3298
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3299
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3300

×
3301
        dbNode, err := db.GetNodeByPubKey(
×
3302
                ctx, sqlc.GetNodeByPubKeyParams{
×
3303
                        Version: int16(ProtocolV1),
×
3304
                        PubKey:  pubKey[:],
×
3305
                },
×
3306
        )
×
3307
        if errors.Is(err, sql.ErrNoRows) {
×
3308
                return 0, nil, ErrGraphNodeNotFound
×
3309
        } else if err != nil {
×
3310
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3311
        }
×
3312

3313
        node, err := buildNode(ctx, db, &dbNode)
×
3314
        if err != nil {
×
3315
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3316
        }
×
3317

3318
        return dbNode.ID, node, nil
×
3319
}
3320

3321
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3322
// provided database channel row and the public keys of the two nodes
3323
// involved in the channel.
3324
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
3325
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3326

×
3327
        return &models.CachedEdgeInfo{
×
3328
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3329
                NodeKey1Bytes: node1Pub,
×
3330
                NodeKey2Bytes: node2Pub,
×
3331
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3332
        }
×
3333
}
×
3334

3335
// buildNode constructs a LightningNode instance from the given database node
3336
// record. The node's features, addresses and extra signed fields are also
3337
// fetched from the database and set on the node.
3338
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3339
        *models.LightningNode, error) {
×
3340

×
3341
        if dbNode.Version != int16(ProtocolV1) {
×
3342
                return nil, fmt.Errorf("unsupported node version: %d",
×
3343
                        dbNode.Version)
×
3344
        }
×
3345

3346
        var pub [33]byte
×
3347
        copy(pub[:], dbNode.PubKey)
×
3348

×
3349
        node := &models.LightningNode{
×
3350
                PubKeyBytes: pub,
×
3351
                Features:    lnwire.EmptyFeatureVector(),
×
3352
                LastUpdate:  time.Unix(0, 0),
×
3353
        }
×
3354

×
3355
        if len(dbNode.Signature) == 0 {
×
3356
                return node, nil
×
3357
        }
×
3358

3359
        node.HaveNodeAnnouncement = true
×
3360
        node.AuthSigBytes = dbNode.Signature
×
3361
        node.Alias = dbNode.Alias.String
×
3362
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3363

×
3364
        var err error
×
3365
        if dbNode.Color.Valid {
×
3366
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3367
                if err != nil {
×
3368
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3369
                                err)
×
3370
                }
×
3371
        }
3372

3373
        // Fetch the node's features.
3374
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3375
        if err != nil {
×
3376
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3377
                        "features: %w", dbNode.ID, err)
×
3378
        }
×
3379

3380
        // Fetch the node's addresses.
3381
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3382
        if err != nil {
×
3383
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3384
                        "addresses: %w", dbNode.ID, err)
×
3385
        }
×
3386

3387
        // Fetch the node's extra signed fields.
3388
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3389
        if err != nil {
×
3390
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3391
                        "extra signed fields: %w", dbNode.ID, err)
×
3392
        }
×
3393

3394
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3395
        if err != nil {
×
3396
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3397
                        "fields: %w", err)
×
3398
        }
×
3399

3400
        if len(recs) != 0 {
×
3401
                node.ExtraOpaqueData = recs
×
3402
        }
×
3403

3404
        return node, nil
×
3405
}
3406

3407
// getNodeFeatures fetches the feature bits and constructs the feature vector
3408
// for a node with the given DB ID.
3409
func getNodeFeatures(ctx context.Context, db SQLQueries,
3410
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3411

×
3412
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3413
        if err != nil {
×
3414
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3415
                        nodeID, err)
×
3416
        }
×
3417

3418
        features := lnwire.EmptyFeatureVector()
×
3419
        for _, feature := range rows {
×
3420
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3421
        }
×
3422

3423
        return features, nil
×
3424
}
3425

3426
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3427
// given DB ID.
3428
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3429
        nodeID int64) (map[uint64][]byte, error) {
×
3430

×
3431
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3432
        if err != nil {
×
3433
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3434
                        "signed fields: %w", nodeID, err)
×
3435
        }
×
3436

3437
        extraFields := make(map[uint64][]byte)
×
3438
        for _, field := range fields {
×
3439
                extraFields[uint64(field.Type)] = field.Value
×
3440
        }
×
3441

3442
        return extraFields, nil
×
3443
}
3444

3445
// upsertNode upserts the node record into the database. If the node already
3446
// exists, then the node's information is updated. If the node doesn't exist,
3447
// then a new node is created. The node's features, addresses and extra TLV
3448
// types are also updated. The node's DB ID is returned.
3449
func upsertNode(ctx context.Context, db SQLQueries,
3450
        node *models.LightningNode) (int64, error) {
×
3451

×
3452
        params := sqlc.UpsertNodeParams{
×
3453
                Version: int16(ProtocolV1),
×
3454
                PubKey:  node.PubKeyBytes[:],
×
3455
        }
×
3456

×
3457
        if node.HaveNodeAnnouncement {
×
3458
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3459
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3460
                params.Alias = sqldb.SQLStr(node.Alias)
×
3461
                params.Signature = node.AuthSigBytes
×
3462
        }
×
3463

3464
        nodeID, err := db.UpsertNode(ctx, params)
×
3465
        if err != nil {
×
3466
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3467
                        err)
×
3468
        }
×
3469

3470
        // We can exit here if we don't have the announcement yet.
3471
        if !node.HaveNodeAnnouncement {
×
3472
                return nodeID, nil
×
3473
        }
×
3474

3475
        // Update the node's features.
3476
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3477
        if err != nil {
×
3478
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3479
        }
×
3480

3481
        // Update the node's addresses.
3482
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3483
        if err != nil {
×
3484
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3485
        }
×
3486

3487
        // Convert the flat extra opaque data into a map of TLV types to
3488
        // values.
3489
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3490
        if err != nil {
×
3491
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3492
                        err)
×
3493
        }
×
3494

3495
        // Update the node's extra signed fields.
3496
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3497
        if err != nil {
×
3498
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3499
        }
×
3500

3501
        return nodeID, nil
×
3502
}
3503

3504
// upsertNodeFeatures updates the node's features node_features table. This
3505
// includes deleting any feature bits no longer present and inserting any new
3506
// feature bits. If the feature bit does not yet exist in the features table,
3507
// then an entry is created in that table first.
3508
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3509
        features *lnwire.FeatureVector) error {
×
3510

×
3511
        // Get any existing features for the node.
×
3512
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3513
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3514
                return err
×
3515
        }
×
3516

3517
        // Copy the nodes latest set of feature bits.
3518
        newFeatures := make(map[int32]struct{})
×
3519
        if features != nil {
×
3520
                for feature := range features.Features() {
×
3521
                        newFeatures[int32(feature)] = struct{}{}
×
3522
                }
×
3523
        }
3524

3525
        // For any current feature that already exists in the DB, remove it from
3526
        // the in-memory map. For any existing feature that does not exist in
3527
        // the in-memory map, delete it from the database.
3528
        for _, feature := range existingFeatures {
×
3529
                // The feature is still present, so there are no updates to be
×
3530
                // made.
×
3531
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3532
                        delete(newFeatures, feature.FeatureBit)
×
3533
                        continue
×
3534
                }
3535

3536
                // The feature is no longer present, so we remove it from the
3537
                // database.
3538
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3539
                        NodeID:     nodeID,
×
3540
                        FeatureBit: feature.FeatureBit,
×
3541
                })
×
3542
                if err != nil {
×
3543
                        return fmt.Errorf("unable to delete node(%d) "+
×
3544
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3545
                                err)
×
3546
                }
×
3547
        }
3548

3549
        // Any remaining entries in newFeatures are new features that need to be
3550
        // added to the database for the first time.
3551
        for feature := range newFeatures {
×
3552
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3553
                        NodeID:     nodeID,
×
3554
                        FeatureBit: feature,
×
3555
                })
×
3556
                if err != nil {
×
3557
                        return fmt.Errorf("unable to insert node(%d) "+
×
3558
                                "feature(%v): %w", nodeID, feature, err)
×
3559
                }
×
3560
        }
3561

3562
        return nil
×
3563
}
3564

3565
// fetchNodeFeatures fetches the features for a node with the given public key.
3566
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3567
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3568

×
3569
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3570
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3571
                        PubKey:  nodePub[:],
×
3572
                        Version: int16(ProtocolV1),
×
3573
                },
×
3574
        )
×
3575
        if err != nil {
×
3576
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3577
                        nodePub, err)
×
3578
        }
×
3579

3580
        features := lnwire.EmptyFeatureVector()
×
3581
        for _, bit := range rows {
×
3582
                features.Set(lnwire.FeatureBit(bit))
×
3583
        }
×
3584

3585
        return features, nil
×
3586
}
3587

3588
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
//
// NOTE: these values are written to the DB as each address row's type and are
// read back by casting that column to dbAddressType, so existing values must
// never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 is used for *net.TCPAddr values whose IP has an
	// IPv4 form (IP.To4() != nil).
	addressTypeIPv4   dbAddressType = 1

	// addressTypeIPv6 is used for all other *net.TCPAddr values.
	addressTypeIPv6   dbAddressType = 2

	// addressTypeTorV2 is used for *tor.OnionAddr values whose onion
	// service string has length tor.V2Len.
	addressTypeTorV2  dbAddressType = 3

	// addressTypeTorV3 is used for *tor.OnionAddr values whose onion
	// service string has length tor.V3Len.
	addressTypeTorV3  dbAddressType = 4

	// addressTypeOpaque is used for *lnwire.OpaqueAddrs values, whose
	// stored string form is hex-decoded when read back. Its value is
	// math.MaxInt8 (127).
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3600

3601
// upsertNodeAddresses updates the node's addresses in the database. This
3602
// includes deleting any existing addresses and inserting the new set of
3603
// addresses. The deletion is necessary since the ordering of the addresses may
3604
// change, and we need to ensure that the database reflects the latest set of
3605
// addresses so that at the time of reconstructing the node announcement, the
3606
// order is preserved and the signature over the message remains valid.
3607
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3608
        addresses []net.Addr) error {
×
3609

×
3610
        // Delete any existing addresses for the node. This is required since
×
3611
        // even if the new set of addresses is the same, the ordering may have
×
3612
        // changed for a given address type.
×
3613
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3614
        if err != nil {
×
3615
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3616
                        nodeID, err)
×
3617
        }
×
3618

3619
        // Copy the nodes latest set of addresses.
3620
        newAddresses := map[dbAddressType][]string{
×
3621
                addressTypeIPv4:   {},
×
3622
                addressTypeIPv6:   {},
×
3623
                addressTypeTorV2:  {},
×
3624
                addressTypeTorV3:  {},
×
3625
                addressTypeOpaque: {},
×
3626
        }
×
3627
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3628
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3629
        }
×
3630

3631
        for _, address := range addresses {
×
3632
                switch addr := address.(type) {
×
3633
                case *net.TCPAddr:
×
3634
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3635
                                addAddr(addressTypeIPv4, addr)
×
3636
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3637
                                addAddr(addressTypeIPv6, addr)
×
3638
                        } else {
×
3639
                                return fmt.Errorf("unhandled IP address: %v",
×
3640
                                        addr)
×
3641
                        }
×
3642

3643
                case *tor.OnionAddr:
×
3644
                        switch len(addr.OnionService) {
×
3645
                        case tor.V2Len:
×
3646
                                addAddr(addressTypeTorV2, addr)
×
3647
                        case tor.V3Len:
×
3648
                                addAddr(addressTypeTorV3, addr)
×
3649
                        default:
×
3650
                                return fmt.Errorf("invalid length for a tor " +
×
3651
                                        "address")
×
3652
                        }
3653

3654
                case *lnwire.OpaqueAddrs:
×
3655
                        addAddr(addressTypeOpaque, addr)
×
3656

3657
                default:
×
3658
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3659
                }
3660
        }
3661

3662
        // Any remaining entries in newAddresses are new addresses that need to
3663
        // be added to the database for the first time.
3664
        for addrType, addrList := range newAddresses {
×
3665
                for position, addr := range addrList {
×
3666
                        err := db.InsertNodeAddress(
×
3667
                                ctx, sqlc.InsertNodeAddressParams{
×
3668
                                        NodeID:   nodeID,
×
3669
                                        Type:     int16(addrType),
×
3670
                                        Address:  addr,
×
3671
                                        Position: int32(position),
×
3672
                                },
×
3673
                        )
×
3674
                        if err != nil {
×
3675
                                return fmt.Errorf("unable to insert "+
×
3676
                                        "node(%d) address(%v): %w", nodeID,
×
3677
                                        addr, err)
×
3678
                        }
×
3679
                }
3680
        }
3681

3682
        return nil
×
3683
}
3684

3685
// getNodeAddresses fetches the addresses for a node with the given public key.
3686
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3687
        []net.Addr, error) {
×
3688

×
3689
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3690
        // are returned in the same order as they were inserted.
×
3691
        rows, err := db.GetNodeAddressesByPubKey(
×
3692
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3693
                        Version: int16(ProtocolV1),
×
3694
                        PubKey:  nodePub,
×
3695
                },
×
3696
        )
×
3697
        if err != nil {
×
3698
                return false, nil, err
×
3699
        }
×
3700

3701
        // GetNodeAddressesByPubKey uses a left join so there should always be
3702
        // at least one row returned if the node exists even if it has no
3703
        // addresses.
3704
        if len(rows) == 0 {
×
3705
                return false, nil, nil
×
3706
        }
×
3707

3708
        addresses := make([]net.Addr, 0, len(rows))
×
3709
        for _, addr := range rows {
×
3710
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3711
                        continue
×
3712
                }
3713

3714
                address := addr.Address.String
×
3715

×
3716
                switch dbAddressType(addr.Type.Int16) {
×
3717
                case addressTypeIPv4:
×
3718
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3719
                        if err != nil {
×
3720
                                return false, nil, nil
×
3721
                        }
×
3722
                        tcp.IP = tcp.IP.To4()
×
3723

×
3724
                        addresses = append(addresses, tcp)
×
3725

3726
                case addressTypeIPv6:
×
3727
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3728
                        if err != nil {
×
3729
                                return false, nil, nil
×
3730
                        }
×
3731
                        addresses = append(addresses, tcp)
×
3732

3733
                case addressTypeTorV3, addressTypeTorV2:
×
3734
                        service, portStr, err := net.SplitHostPort(address)
×
3735
                        if err != nil {
×
3736
                                return false, nil, fmt.Errorf("unable to "+
×
3737
                                        "split tor v3 address: %v",
×
3738
                                        addr.Address)
×
3739
                        }
×
3740

3741
                        port, err := strconv.Atoi(portStr)
×
3742
                        if err != nil {
×
3743
                                return false, nil, err
×
3744
                        }
×
3745

3746
                        addresses = append(addresses, &tor.OnionAddr{
×
3747
                                OnionService: service,
×
3748
                                Port:         port,
×
3749
                        })
×
3750

3751
                case addressTypeOpaque:
×
3752
                        opaque, err := hex.DecodeString(address)
×
3753
                        if err != nil {
×
3754
                                return false, nil, fmt.Errorf("unable to "+
×
3755
                                        "decode opaque address: %v", addr)
×
3756
                        }
×
3757

3758
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3759
                                Payload: opaque,
×
3760
                        })
×
3761

3762
                default:
×
3763
                        return false, nil, fmt.Errorf("unknown address "+
×
3764
                                "type: %v", addr.Type)
×
3765
                }
3766
        }
3767

3768
        // If we have no addresses, then we'll return nil instead of an
3769
        // empty slice.
3770
        if len(addresses) == 0 {
×
3771
                addresses = nil
×
3772
        }
×
3773

3774
        return true, addresses, nil
×
3775
}
3776

3777
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3778
// database. This includes updating any existing types, inserting any new types,
3779
// and deleting any types that are no longer present.
3780
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3781
        nodeID int64, extraFields map[uint64][]byte) error {
×
3782

×
3783
        // Get any existing extra signed fields for the node.
×
3784
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3785
        if err != nil {
×
3786
                return err
×
3787
        }
×
3788

3789
        // Make a lookup map of the existing field types so that we can use it
3790
        // to keep track of any fields we should delete.
3791
        m := make(map[uint64]bool)
×
3792
        for _, field := range existingFields {
×
3793
                m[uint64(field.Type)] = true
×
3794
        }
×
3795

3796
        // For all the new fields, we'll upsert them and remove them from the
3797
        // map of existing fields.
3798
        for tlvType, value := range extraFields {
×
3799
                err = db.UpsertNodeExtraType(
×
3800
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3801
                                NodeID: nodeID,
×
3802
                                Type:   int64(tlvType),
×
3803
                                Value:  value,
×
3804
                        },
×
3805
                )
×
3806
                if err != nil {
×
3807
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3808
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3809
                }
×
3810

3811
                // Remove the field from the map of existing fields if it was
3812
                // present.
3813
                delete(m, tlvType)
×
3814
        }
3815

3816
        // For all the fields that are left in the map of existing fields, we'll
3817
        // delete them as they are no longer present in the new set of fields.
3818
        for tlvType := range m {
×
3819
                err = db.DeleteExtraNodeType(
×
3820
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3821
                                NodeID: nodeID,
×
3822
                                Type:   int64(tlvType),
×
3823
                        },
×
3824
                )
×
3825
                if err != nil {
×
3826
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3827
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3828
                }
×
3829
        }
3830

3831
        return nil
×
3832
}
3833

3834
// srcNodeInfo holds the information about the source node of the graph.
// getSourceNode caches one instance per protocol version in SQLStore.srcNodes.
type srcNodeInfo struct {
	// id is the DB level ID of the source node entry in the "nodes" table.
	id int64

	// pub is the public key of the source node.
	pub route.Vertex
}
3842

3843
// sourceNode returns the DB node ID and pub key of the source node for the
3844
// specified protocol version.
3845
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3846
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3847

×
3848
        s.srcNodeMu.Lock()
×
3849
        defer s.srcNodeMu.Unlock()
×
3850

×
3851
        // If we already have the source node ID and pub key cached, then
×
3852
        // return them.
×
3853
        if info, ok := s.srcNodes[version]; ok {
×
3854
                return info.id, info.pub, nil
×
3855
        }
×
3856

3857
        var pubKey route.Vertex
×
3858

×
3859
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3860
        if err != nil {
×
3861
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3862
                        err)
×
3863
        }
×
3864

3865
        if len(nodes) == 0 {
×
3866
                return 0, pubKey, ErrSourceNodeNotSet
×
3867
        } else if len(nodes) > 1 {
×
3868
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3869
                        "protocol %s found", version)
×
3870
        }
×
3871

3872
        copy(pubKey[:], nodes[0].PubKey)
×
3873

×
3874
        s.srcNodes[version] = &srcNodeInfo{
×
3875
                id:  nodes[0].NodeID,
×
3876
                pub: pubKey,
×
3877
        }
×
3878

×
3879
        return nodes[0].NodeID, pubKey, nil
×
3880
}
3881

3882
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3883
// This then produces a map from TLV type to value. If the input is not a
3884
// valid TLV stream, then an error is returned.
3885
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3886
        r := bytes.NewReader(data)
×
3887

×
3888
        tlvStream, err := tlv.NewStream()
×
3889
        if err != nil {
×
3890
                return nil, err
×
3891
        }
×
3892

3893
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3894
        // pass it into the P2P decoding variant.
3895
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3896
        if err != nil {
×
3897
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3898
        }
×
3899
        if len(parsedTypes) == 0 {
×
3900
                return nil, nil
×
3901
        }
×
3902

3903
        records := make(map[uint64][]byte)
×
3904
        for k, v := range parsedTypes {
×
3905
                records[uint64(k)] = v
×
3906
        }
×
3907

3908
        return records, nil
×
3909
}
3910

3911
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
// channel. insertChannel returns it after a successful insert.
type dbChanInfo struct {
	// channelID is the DB ID returned by CreateChannel for the channel's
	// row (not the short channel ID).
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and second
	// node rows respectively.
	node1ID int64
	node2ID int64
}
3918

3919
// insertChannel inserts a new channel record into the database.
3920
func insertChannel(ctx context.Context, db SQLQueries,
3921
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3922

×
3923
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3924

×
3925
        // Make sure that the channel doesn't already exist. We do this
×
3926
        // explicitly instead of relying on catching a unique constraint error
×
3927
        // because relying on SQL to throw that error would abort the entire
×
3928
        // batch of transactions.
×
3929
        _, err := db.GetChannelBySCID(
×
3930
                ctx, sqlc.GetChannelBySCIDParams{
×
3931
                        Scid:    chanIDB,
×
3932
                        Version: int16(ProtocolV1),
×
3933
                },
×
3934
        )
×
3935
        if err == nil {
×
3936
                return nil, ErrEdgeAlreadyExist
×
3937
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3938
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3939
        }
×
3940

3941
        // Make sure that at least a "shell" entry for each node is present in
3942
        // the nodes table.
3943
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3944
        if err != nil {
×
3945
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3946
        }
×
3947

3948
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3949
        if err != nil {
×
3950
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3951
        }
×
3952

3953
        var capacity sql.NullInt64
×
3954
        if edge.Capacity != 0 {
×
3955
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3956
        }
×
3957

3958
        createParams := sqlc.CreateChannelParams{
×
3959
                Version:     int16(ProtocolV1),
×
3960
                Scid:        chanIDB,
×
3961
                NodeID1:     node1DBID,
×
3962
                NodeID2:     node2DBID,
×
3963
                Outpoint:    edge.ChannelPoint.String(),
×
3964
                Capacity:    capacity,
×
3965
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3966
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3967
        }
×
3968

×
3969
        if edge.AuthProof != nil {
×
3970
                proof := edge.AuthProof
×
3971

×
3972
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3973
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3974
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3975
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3976
        }
×
3977

3978
        // Insert the new channel record.
3979
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3980
        if err != nil {
×
3981
                return nil, err
×
3982
        }
×
3983

3984
        // Insert any channel features.
3985
        for feature := range edge.Features.Features() {
×
3986
                err = db.InsertChannelFeature(
×
3987
                        ctx, sqlc.InsertChannelFeatureParams{
×
3988
                                ChannelID:  dbChanID,
×
3989
                                FeatureBit: int32(feature),
×
3990
                        },
×
3991
                )
×
3992
                if err != nil {
×
3993
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3994
                                "feature(%v): %w", dbChanID, feature, err)
×
3995
                }
×
3996
        }
3997

3998
        // Finally, insert any extra TLV fields in the channel announcement.
3999
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4000
        if err != nil {
×
4001
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
4002
                        "data: %w", err)
×
4003
        }
×
4004

4005
        for tlvType, value := range extra {
×
4006
                err := db.CreateChannelExtraType(
×
4007
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
4008
                                ChannelID: dbChanID,
×
4009
                                Type:      int64(tlvType),
×
4010
                                Value:     value,
×
4011
                        },
×
4012
                )
×
4013
                if err != nil {
×
4014
                        return nil, fmt.Errorf("unable to upsert "+
×
4015
                                "channel(%d) extra signed field(%v): %w",
×
4016
                                edge.ChannelID, tlvType, err)
×
4017
                }
×
4018
        }
4019

4020
        return &dbChanInfo{
×
4021
                channelID: dbChanID,
×
4022
                node1ID:   node1DBID,
×
4023
                node2ID:   node2DBID,
×
4024
        }, nil
×
4025
}
4026

4027
// maybeCreateShellNode checks if a shell node entry exists for the
4028
// given public key. If it does not exist, then a new shell node entry is
4029
// created. The ID of the node is returned. A shell node only has a protocol
4030
// version and public key persisted.
4031
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4032
        pubKey route.Vertex) (int64, error) {
×
4033

×
4034
        dbNode, err := db.GetNodeByPubKey(
×
4035
                ctx, sqlc.GetNodeByPubKeyParams{
×
4036
                        PubKey:  pubKey[:],
×
4037
                        Version: int16(ProtocolV1),
×
4038
                },
×
4039
        )
×
4040
        // The node exists. Return the ID.
×
4041
        if err == nil {
×
4042
                return dbNode.ID, nil
×
4043
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4044
                return 0, err
×
4045
        }
×
4046

4047
        // Otherwise, the node does not exist, so we create a shell entry for
4048
        // it.
4049
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4050
                Version: int16(ProtocolV1),
×
4051
                PubKey:  pubKey[:],
×
4052
        })
×
4053
        if err != nil {
×
4054
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4055
        }
×
4056

4057
        return id, nil
×
4058
}
4059

4060
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4061
// the database. This includes deleting any existing types and then inserting
4062
// the new types.
4063
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4064
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4065

×
4066
        // Delete all existing extra signed fields for the channel policy.
×
4067
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4068
        if err != nil {
×
4069
                return fmt.Errorf("unable to delete "+
×
4070
                        "existing policy extra signed fields for policy %d: %w",
×
4071
                        chanPolicyID, err)
×
4072
        }
×
4073

4074
        // Insert all new extra signed fields for the channel policy.
4075
        for tlvType, value := range extraFields {
×
4076
                err = db.InsertChanPolicyExtraType(
×
4077
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4078
                                ChannelPolicyID: chanPolicyID,
×
4079
                                Type:            int64(tlvType),
×
4080
                                Value:           value,
×
4081
                        },
×
4082
                )
×
4083
                if err != nil {
×
4084
                        return fmt.Errorf("unable to insert "+
×
4085
                                "channel_policy(%d) extra signed field(%v): %w",
×
4086
                                chanPolicyID, tlvType, err)
×
4087
                }
×
4088
        }
4089

4090
        return nil
×
4091
}
4092

4093
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4094
// provided dbChanRow and also fetches any other required information
4095
// to construct the edge info.
4096
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4097
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
4098
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4099

×
4100
        if dbChan.Version != int16(ProtocolV1) {
×
4101
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4102
                        dbChan.Version)
×
4103
        }
×
4104

4105
        fv, extras, err := getChanFeaturesAndExtras(
×
4106
                ctx, db, dbChanID,
×
4107
        )
×
4108
        if err != nil {
×
4109
                return nil, err
×
4110
        }
×
4111

4112
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4113
        if err != nil {
×
4114
                return nil, err
×
4115
        }
×
4116

4117
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4118
        if err != nil {
×
4119
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4120
                        "fields: %w", err)
×
4121
        }
×
4122
        if recs == nil {
×
4123
                recs = make([]byte, 0)
×
4124
        }
×
4125

4126
        var btcKey1, btcKey2 route.Vertex
×
4127
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4128
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4129

×
4130
        channel := &models.ChannelEdgeInfo{
×
4131
                ChainHash:        chain,
×
4132
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4133
                NodeKey1Bytes:    node1,
×
4134
                NodeKey2Bytes:    node2,
×
4135
                BitcoinKey1Bytes: btcKey1,
×
4136
                BitcoinKey2Bytes: btcKey2,
×
4137
                ChannelPoint:     *op,
×
4138
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4139
                Features:         fv,
×
4140
                ExtraOpaqueData:  recs,
×
4141
        }
×
4142

×
4143
        // We always set all the signatures at the same time, so we can
×
4144
        // safely check if one signature is present to determine if we have the
×
4145
        // rest of the signatures for the auth proof.
×
4146
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4147
                channel.AuthProof = &models.ChannelAuthProof{
×
4148
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4149
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4150
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4151
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4152
                }
×
4153
        }
×
4154

4155
        return channel, nil
×
4156
}
4157

4158
// buildNodeVertices is a helper that converts raw node public keys
4159
// into route.Vertex instances.
4160
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4161
        route.Vertex, error) {
×
4162

×
4163
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4164
        if err != nil {
×
4165
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4166
                        "create vertex from node1 pubkey: %w", err)
×
4167
        }
×
4168

4169
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4170
        if err != nil {
×
4171
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4172
                        "create vertex from node2 pubkey: %w", err)
×
4173
        }
×
4174

4175
        return node1Vertex, node2Vertex, nil
×
4176
}
4177

4178
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4179
// for a channel with the given ID.
4180
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4181
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4182

×
4183
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4184
        if err != nil {
×
4185
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4186
                        "features and extras: %w", err)
×
4187
        }
×
4188

4189
        var (
×
4190
                fv     = lnwire.EmptyFeatureVector()
×
4191
                extras = make(map[uint64][]byte)
×
4192
        )
×
4193
        for _, row := range rows {
×
4194
                if row.IsFeature {
×
4195
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4196

×
4197
                        continue
×
4198
                }
4199

4200
                tlvType, ok := row.ExtraKey.(int64)
×
4201
                if !ok {
×
4202
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4203
                                "TLV type: %T", row.ExtraKey)
×
4204
                }
×
4205

4206
                valueBytes, ok := row.Value.([]byte)
×
4207
                if !ok {
×
4208
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4209
                                "Value: %T", row.Value)
×
4210
                }
×
4211

4212
                extras[uint64(tlvType)] = valueBytes
×
4213
        }
4214

4215
        return fv, extras, nil
×
4216
}
4217

4218
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4219
// retrieves all the extra info required to build the complete
4220
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4221
// the provided sqlc.GraphChannelPolicy records are nil.
4222
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4223
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4224
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4225
        *models.ChannelEdgePolicy, error) {
×
4226

×
4227
        if dbPol1 == nil && dbPol2 == nil {
×
4228
                return nil, nil, nil
×
4229
        }
×
4230

4231
        var (
×
4232
                policy1ID int64
×
4233
                policy2ID int64
×
4234
        )
×
4235
        if dbPol1 != nil {
×
4236
                policy1ID = dbPol1.ID
×
4237
        }
×
4238
        if dbPol2 != nil {
×
4239
                policy2ID = dbPol2.ID
×
4240
        }
×
4241
        rows, err := db.GetChannelPolicyExtraTypes(
×
4242
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4243
                        ID:   policy1ID,
×
4244
                        ID_2: policy2ID,
×
4245
                },
×
4246
        )
×
4247
        if err != nil {
×
4248
                return nil, nil, err
×
4249
        }
×
4250

4251
        var (
×
4252
                dbPol1Extras = make(map[uint64][]byte)
×
4253
                dbPol2Extras = make(map[uint64][]byte)
×
4254
        )
×
4255
        for _, row := range rows {
×
4256
                switch row.PolicyID {
×
4257
                case policy1ID:
×
4258
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4259
                case policy2ID:
×
4260
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4261
                default:
×
4262
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4263
                                "in row: %v", row.PolicyID, row)
×
4264
                }
4265
        }
4266

4267
        var pol1, pol2 *models.ChannelEdgePolicy
×
4268
        if dbPol1 != nil {
×
4269
                pol1, err = buildChanPolicy(
×
4270
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4271
                )
×
4272
                if err != nil {
×
4273
                        return nil, nil, err
×
4274
                }
×
4275
        }
4276
        if dbPol2 != nil {
×
4277
                pol2, err = buildChanPolicy(
×
4278
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4279
                )
×
4280
                if err != nil {
×
4281
                        return nil, nil, err
×
4282
                }
×
4283
        }
4284

4285
        return pol1, pol2, nil
×
4286
}
4287

4288
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4289
// provided sqlc.GraphChannelPolicy and other required information.
4290
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4291
        extras map[uint64][]byte,
4292
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4293

×
4294
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4295
        if err != nil {
×
4296
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4297
                        "fields: %w", err)
×
4298
        }
×
4299

4300
        var inboundFee fn.Option[lnwire.Fee]
×
4301
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4302
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4303

×
4304
                inboundFee = fn.Some(lnwire.Fee{
×
4305
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4306
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4307
                })
×
4308
        }
×
4309

4310
        return &models.ChannelEdgePolicy{
×
4311
                SigBytes:  dbPolicy.Signature,
×
4312
                ChannelID: channelID,
×
4313
                LastUpdate: time.Unix(
×
4314
                        dbPolicy.LastUpdate.Int64, 0,
×
4315
                ),
×
4316
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4317
                        dbPolicy.MessageFlags,
×
4318
                ),
×
4319
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4320
                        dbPolicy.ChannelFlags,
×
4321
                ),
×
4322
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4323
                MinHTLC: lnwire.MilliSatoshi(
×
4324
                        dbPolicy.MinHtlcMsat,
×
4325
                ),
×
4326
                MaxHTLC: lnwire.MilliSatoshi(
×
4327
                        dbPolicy.MaxHtlcMsat.Int64,
×
4328
                ),
×
4329
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4330
                        dbPolicy.BaseFeeMsat,
×
4331
                ),
×
4332
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4333
                ToNode:                    toNode,
×
4334
                InboundFee:                inboundFee,
×
4335
                ExtraOpaqueData:           recs,
×
4336
        }, nil
×
4337
}
4338

4339
// buildNodes builds the models.LightningNode instances for the
4340
// given row which is expected to be a sqlc type that contains node information.
4341
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4342
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4343
        error) {
×
4344

×
4345
        node1, err := buildNode(ctx, db, &dbNode1)
×
4346
        if err != nil {
×
4347
                return nil, nil, err
×
4348
        }
×
4349

4350
        node2, err := buildNode(ctx, db, &dbNode2)
×
4351
        if err != nil {
×
4352
                return nil, nil, err
×
4353
        }
×
4354

4355
        return node1, node2, nil
×
4356
}
4357

4358
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4359
// row which is expected to be a sqlc type that contains channel policy
4360
// information. It returns two policies, which may be nil if the policy
4361
// information is not present in the row.
4362
//
4363
//nolint:ll,dupl,funlen
4364
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4365
        *sqlc.GraphChannelPolicy, error) {
×
4366

×
4367
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4368
        switch r := row.(type) {
×
NEW
4369
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
NEW
4370
                if r.Policy1ID.Valid {
×
NEW
4371
                        policy1 = &sqlc.GraphChannelPolicy{
×
NEW
4372
                                ID:                      r.Policy1ID.Int64,
×
NEW
4373
                                Version:                 r.Policy1Version.Int16,
×
NEW
4374
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4375
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4376
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4377
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4378
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4379
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4380
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4381
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4382
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4383
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4384
                                Disabled:                r.Policy1Disabled,
×
NEW
4385
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4386
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
4387
                                Signature:               r.Policy1Signature,
×
NEW
4388
                        }
×
NEW
4389
                }
×
NEW
4390
                if r.Policy2ID.Valid {
×
NEW
4391
                        policy2 = &sqlc.GraphChannelPolicy{
×
NEW
4392
                                ID:                      r.Policy2ID.Int64,
×
NEW
4393
                                Version:                 r.Policy2Version.Int16,
×
NEW
4394
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4395
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4396
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4397
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4398
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4399
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4400
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4401
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4402
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4403
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4404
                                Disabled:                r.Policy2Disabled,
×
NEW
4405
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4406
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
4407
                                Signature:               r.Policy2Signature,
×
NEW
4408
                        }
×
NEW
4409
                }
×
4410

NEW
4411
                return policy1, policy2, nil
×
4412

4413
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4414
                if r.Policy1ID.Valid {
×
4415
                        policy1 = &sqlc.GraphChannelPolicy{
×
4416
                                ID:                      r.Policy1ID.Int64,
×
4417
                                Version:                 r.Policy1Version.Int16,
×
4418
                                ChannelID:               r.GraphChannel.ID,
×
4419
                                NodeID:                  r.Policy1NodeID.Int64,
×
4420
                                Timelock:                r.Policy1Timelock.Int32,
×
4421
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4422
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4423
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4424
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4425
                                LastUpdate:              r.Policy1LastUpdate,
×
4426
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4427
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4428
                                Disabled:                r.Policy1Disabled,
×
4429
                                MessageFlags:            r.Policy1MessageFlags,
×
4430
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4431
                                Signature:               r.Policy1Signature,
×
4432
                        }
×
4433
                }
×
4434
                if r.Policy2ID.Valid {
×
4435
                        policy2 = &sqlc.GraphChannelPolicy{
×
4436
                                ID:                      r.Policy2ID.Int64,
×
4437
                                Version:                 r.Policy2Version.Int16,
×
4438
                                ChannelID:               r.GraphChannel.ID,
×
4439
                                NodeID:                  r.Policy2NodeID.Int64,
×
4440
                                Timelock:                r.Policy2Timelock.Int32,
×
4441
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4442
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4443
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4444
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4445
                                LastUpdate:              r.Policy2LastUpdate,
×
4446
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4447
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4448
                                Disabled:                r.Policy2Disabled,
×
4449
                                MessageFlags:            r.Policy2MessageFlags,
×
4450
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4451
                                Signature:               r.Policy2Signature,
×
4452
                        }
×
4453
                }
×
4454

4455
                return policy1, policy2, nil
×
4456

4457
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4458
                if r.Policy1ID.Valid {
×
4459
                        policy1 = &sqlc.GraphChannelPolicy{
×
4460
                                ID:                      r.Policy1ID.Int64,
×
4461
                                Version:                 r.Policy1Version.Int16,
×
4462
                                ChannelID:               r.GraphChannel.ID,
×
4463
                                NodeID:                  r.Policy1NodeID.Int64,
×
4464
                                Timelock:                r.Policy1Timelock.Int32,
×
4465
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4466
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4467
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4468
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4469
                                LastUpdate:              r.Policy1LastUpdate,
×
4470
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4471
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4472
                                Disabled:                r.Policy1Disabled,
×
4473
                                MessageFlags:            r.Policy1MessageFlags,
×
4474
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4475
                                Signature:               r.Policy1Signature,
×
4476
                        }
×
4477
                }
×
4478
                if r.Policy2ID.Valid {
×
4479
                        policy2 = &sqlc.GraphChannelPolicy{
×
4480
                                ID:                      r.Policy2ID.Int64,
×
4481
                                Version:                 r.Policy2Version.Int16,
×
4482
                                ChannelID:               r.GraphChannel.ID,
×
4483
                                NodeID:                  r.Policy2NodeID.Int64,
×
4484
                                Timelock:                r.Policy2Timelock.Int32,
×
4485
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4486
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4487
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4488
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4489
                                LastUpdate:              r.Policy2LastUpdate,
×
4490
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4491
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4492
                                Disabled:                r.Policy2Disabled,
×
4493
                                MessageFlags:            r.Policy2MessageFlags,
×
4494
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4495
                                Signature:               r.Policy2Signature,
×
4496
                        }
×
4497
                }
×
4498

4499
                return policy1, policy2, nil
×
4500

4501
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4502
                if r.Policy1ID.Valid {
×
4503
                        policy1 = &sqlc.GraphChannelPolicy{
×
4504
                                ID:                      r.Policy1ID.Int64,
×
4505
                                Version:                 r.Policy1Version.Int16,
×
4506
                                ChannelID:               r.GraphChannel.ID,
×
4507
                                NodeID:                  r.Policy1NodeID.Int64,
×
4508
                                Timelock:                r.Policy1Timelock.Int32,
×
4509
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4510
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4511
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4512
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4513
                                LastUpdate:              r.Policy1LastUpdate,
×
4514
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4515
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4516
                                Disabled:                r.Policy1Disabled,
×
4517
                                MessageFlags:            r.Policy1MessageFlags,
×
4518
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4519
                                Signature:               r.Policy1Signature,
×
4520
                        }
×
4521
                }
×
4522
                if r.Policy2ID.Valid {
×
4523
                        policy2 = &sqlc.GraphChannelPolicy{
×
4524
                                ID:                      r.Policy2ID.Int64,
×
4525
                                Version:                 r.Policy2Version.Int16,
×
4526
                                ChannelID:               r.GraphChannel.ID,
×
4527
                                NodeID:                  r.Policy2NodeID.Int64,
×
4528
                                Timelock:                r.Policy2Timelock.Int32,
×
4529
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4530
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4531
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4532
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4533
                                LastUpdate:              r.Policy2LastUpdate,
×
4534
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4535
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4536
                                Disabled:                r.Policy2Disabled,
×
4537
                                MessageFlags:            r.Policy2MessageFlags,
×
4538
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4539
                                Signature:               r.Policy2Signature,
×
4540
                        }
×
4541
                }
×
4542

4543
                return policy1, policy2, nil
×
4544

4545
        case sqlc.ListChannelsByNodeIDRow:
×
4546
                if r.Policy1ID.Valid {
×
4547
                        policy1 = &sqlc.GraphChannelPolicy{
×
4548
                                ID:                      r.Policy1ID.Int64,
×
4549
                                Version:                 r.Policy1Version.Int16,
×
4550
                                ChannelID:               r.GraphChannel.ID,
×
4551
                                NodeID:                  r.Policy1NodeID.Int64,
×
4552
                                Timelock:                r.Policy1Timelock.Int32,
×
4553
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4554
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4555
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4556
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4557
                                LastUpdate:              r.Policy1LastUpdate,
×
4558
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4559
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4560
                                Disabled:                r.Policy1Disabled,
×
4561
                                MessageFlags:            r.Policy1MessageFlags,
×
4562
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4563
                                Signature:               r.Policy1Signature,
×
4564
                        }
×
4565
                }
×
4566
                if r.Policy2ID.Valid {
×
4567
                        policy2 = &sqlc.GraphChannelPolicy{
×
4568
                                ID:                      r.Policy2ID.Int64,
×
4569
                                Version:                 r.Policy2Version.Int16,
×
4570
                                ChannelID:               r.GraphChannel.ID,
×
4571
                                NodeID:                  r.Policy2NodeID.Int64,
×
4572
                                Timelock:                r.Policy2Timelock.Int32,
×
4573
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4574
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4575
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4576
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4577
                                LastUpdate:              r.Policy2LastUpdate,
×
4578
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4579
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4580
                                Disabled:                r.Policy2Disabled,
×
4581
                                MessageFlags:            r.Policy2MessageFlags,
×
4582
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4583
                                Signature:               r.Policy2Signature,
×
4584
                        }
×
4585
                }
×
4586

4587
                return policy1, policy2, nil
×
4588

4589
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4590
                if r.Policy1ID.Valid {
×
4591
                        policy1 = &sqlc.GraphChannelPolicy{
×
4592
                                ID:                      r.Policy1ID.Int64,
×
4593
                                Version:                 r.Policy1Version.Int16,
×
4594
                                ChannelID:               r.GraphChannel.ID,
×
4595
                                NodeID:                  r.Policy1NodeID.Int64,
×
4596
                                Timelock:                r.Policy1Timelock.Int32,
×
4597
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4598
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4599
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4600
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4601
                                LastUpdate:              r.Policy1LastUpdate,
×
4602
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4603
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4604
                                Disabled:                r.Policy1Disabled,
×
4605
                                MessageFlags:            r.Policy1MessageFlags,
×
4606
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4607
                                Signature:               r.Policy1Signature,
×
4608
                        }
×
4609
                }
×
4610
                if r.Policy2ID.Valid {
×
4611
                        policy2 = &sqlc.GraphChannelPolicy{
×
4612
                                ID:                      r.Policy2ID.Int64,
×
4613
                                Version:                 r.Policy2Version.Int16,
×
4614
                                ChannelID:               r.GraphChannel.ID,
×
4615
                                NodeID:                  r.Policy2NodeID.Int64,
×
4616
                                Timelock:                r.Policy2Timelock.Int32,
×
4617
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4618
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4619
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4620
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4621
                                LastUpdate:              r.Policy2LastUpdate,
×
4622
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4623
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4624
                                Disabled:                r.Policy2Disabled,
×
4625
                                MessageFlags:            r.Policy2MessageFlags,
×
4626
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4627
                                Signature:               r.Policy2Signature,
×
4628
                        }
×
4629
                }
×
4630

4631
                return policy1, policy2, nil
×
4632
        default:
×
4633
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4634
                        "extractChannelPolicies: %T", r)
×
4635
        }
4636
}
4637

4638
// channelIDToBytes converts a channel ID (SCID) to a byte array
4639
// representation.
4640
func channelIDToBytes(channelID uint64) []byte {
×
4641
        var chanIDB [8]byte
×
4642
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4643

×
4644
        return chanIDB[:]
×
4645
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc