
lightningnetwork / lnd / 17370655067

01 Sep 2025 07:23AM UTC coverage: 66.678% (+9.4%) from 57.321%

push · github · web-flow
Merge pull request #10161 from ellemouton/graphRetry

graph/db+sqldb: Make the SQL migration retry-safe/idempotent

0 of 356 new or added lines in 3 files covered. (0.0%)

19 existing lines in 7 files now uncovered.

136009 of 203980 relevant lines covered (66.68%)

21390.06 hits per line

Source File: /graph/db/sql_store.go (0.0% coverage)
package graphdb

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"maps"
	"math"
	"net"
	"slices"
	"strconv"
	"sync"
	"time"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/aliasmgr"
	"github.com/lightningnetwork/lnd/batch"
	"github.com/lightningnetwork/lnd/fn/v2"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/sqldb"
	"github.com/lightningnetwork/lnd/sqldb/sqlc"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
	GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

	/*
		Migration specific queries.

		NOTE: these should not be used in code other than migrations.
		Once sqldbv2 is in place, these can be removed from this struct
		as then migrations will have their own dedicated queries
		structs.
	*/
	InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
	InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
	InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	srcNodes  map[ProtocolVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryConfig holds configuration values for SQL queries.
	QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
	options ...StoreOptionModifier) (*SQLStore, error) {

	opts := DefaultOptions()
	for _, o := range options {
		o(opts)
	}

	if opts.NoMigration {
		return nil, fmt.Errorf("the NoMigration option is not yet " +
			"supported for SQL stores")
	}

	s := &SQLStore{
		cfg:         cfg,
		db:          db,
		rejectCache: newRejectCache(opts.RejectCacheSize),
		chanCache:   newChannelCache(opts.ChannelCacheSize),
		srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
	}

	s.chanScheduler = batch.NewTimeScheduler(
		db, &s.cacheMu, opts.BatchCommitInterval,
	)
	s.nodeScheduler = batch.NewTimeScheduler(
		db, nil, opts.BatchCommitInterval,
	)

	return s, nil
}
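
// The following usage sketch is illustrative only and not part of the upstream
// file: it shows how a caller might construct an SQLStore once it has a
// BatchedSQLQueries backend (assumed here to be lnd's sqldb transaction
// executor over the generated sqlc queries) and a caller-prepared
// *sqldb.QueryConfig.
//
//	cfg := &SQLStoreConfig{
//		// Mainnet genesis hash from btcd's chaincfg package.
//		ChainHash: *chaincfg.MainNetParams.GenesisHash,
//		// queryCfg is an assumed, caller-prepared *sqldb.QueryConfig.
//		QueryCfg: queryCfg,
//	}
//
//	store, err := NewSQLStore(cfg, db)
//	if err != nil {
//		return fmt.Errorf("unable to create graph SQL store: %w", err)
//	}
//	_ = store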

// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
	node *models.LightningNode, opts ...batch.SchedulerOption) error {

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Do: func(queries SQLQueries) error {
			_, err := upsertNode(ctx, queries, node)
			return err
		},
	}

	return s.nodeScheduler.Execute(ctx, r)
}

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
	pubKey route.Vertex) (*models.LightningNode, error) {

	var node *models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		_, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		exists     bool
		lastUpdate time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pubKey[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		exists = true

		if dbNode.LastUpdate.Valid {
			lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates whether the
// given node is known to the graph DB.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		addresses []net.Addr
		known     bool
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, check if the node exists and get its DB ID if it
		// does.
		dbID, err := db.GetNodeIDByPubKey(
			ctx, sqlc.GetNodeIDByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  nodePub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}

		known = true

		addresses, err = getNodeAddresses(ctx, db, dbID)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
	pubKey route.Vertex) error {

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteNodeByPubKey(
			ctx, sqlc.DeleteNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pubKey[:],
			},
		)
		if err != nil {
			return err
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrGraphNodeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d rows, expected 1", rows)
		}

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to delete node: %w", err)
	}

	return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
	*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	var (
		ctx     = context.TODO()
		chanIDs []uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrNodeAliasNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		if !dbNode.Alias.Valid {
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
	error) {

	var node *models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		_, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch V1 source node: %w",
				err)
		}

		_, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
	node *models.LightningNode) error {

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		id, err := upsertNode(ctx, db, node)
		if err != nil {
			return fmt.Errorf("unable to upsert source node: %w",
				err)
		}

		// Make sure that if a source node for this version is already
		// set, then the ID is the same as the one we are about to set.
		dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		} else if err == nil {
			if dbSourceNodeID != id {
				return fmt.Errorf("v1 source node already "+
					"set to a different node: %d vs %d",
					dbSourceNodeID, id)
			}

			return nil
		}

		return db.AddSourceNode(ctx, id)
	}, sqldb.NoOpReset)
}
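
// Illustrative sketch (not part of the upstream file): the source node is
// typically our own node announcement, set once at startup and then read back
// by pathfinding code. "selfNode" is an assumed *models.LightningNode built
// from our own node announcement.
//
//	if err := store.SetSourceNode(ctx, selfNode); err != nil {
//		return fmt.Errorf("unable to set source node: %w", err)
//	}
//
//	src, err := store.SourceNode(ctx)
//	if err != nil {
//		return fmt.Errorf("unable to fetch source node: %w", err)
//	}
//	log.Debugf("source node alias: %s", src.Alias)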

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
	endTime time.Time) ([]models.LightningNode, error) {

	ctx := context.TODO()

	var nodes []models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNodes, err := db.GetNodesByLastUpdateRange(
			ctx, sqlc.GetNodesByLastUpdateRangeParams{
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch nodes: %w", err)
		}

		err = forEachNodeInBatch(
			ctx, s.cfg.QueryCfg, db, dbNodes,
			func(_ int64, node *models.LightningNode) error {
				nodes = append(nodes, *node)

				return nil
			},
		)
		if err != nil {
			return fmt.Errorf("unable to build nodes: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	var alreadyExists bool
	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			alreadyExists = false
		},
		Do: func(tx SQLQueries) error {
			chanIDB := channelIDToBytes(edge.ChannelID)

			// Make sure that the channel doesn't already exist. We
			// do this explicitly instead of relying on catching a
			// unique constraint error because relying on SQL to
			// throw that error would abort the entire batch of
			// transactions.
			_, err := tx.GetChannelBySCID(
				ctx, sqlc.GetChannelBySCIDParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if err == nil {
				alreadyExists = true
				return nil
			} else if !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch channel: %w",
					err)
			}

			return insertChannel(ctx, tx, edge)
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case alreadyExists:
				return ErrEdgeAlreadyExist
			default:
				s.rejectCache.remove(edge.ChannelID)
				s.chanCache.remove(edge.ChannelID)
				return nil
			}
		},
	}

	return s.chanScheduler.Execute(ctx, r)
}
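
// Illustrative sketch (not part of the upstream file): AddChannelEdge goes
// through the channel batch scheduler, so a duplicate announcement only
// surfaces as ErrEdgeAlreadyExist once the batch commits (see OnCommit above).
// "edge" is an assumed *models.ChannelEdgeInfo decoded from a gossip
// channel_announcement.
//
//	err := store.AddChannelEdge(ctx, edge)
//	switch {
//	case errors.Is(err, ErrEdgeAlreadyExist):
//		// The SCID is already known; safe to ignore.
//	case err != nil:
//		return fmt.Errorf("unable to add channel edge: %w", err)
//	}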

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
	var highestChanID uint64
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch highest chan ID: %w",
				err)
		}

		highestChanID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
	}

	return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
	edge *models.ChannelEdgePolicy,
	opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

	var (
		isUpdate1    bool
		edgeNotFound bool
		from, to     route.Vertex
	)

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			isUpdate1 = false
			edgeNotFound = false
		},
		Do: func(tx SQLQueries) error {
			var err error
			from, to, isUpdate1, err = updateChanEdgePolicy(
				ctx, tx, edge,
			)
			if err != nil {
				log.Errorf("UpdateEdgePolicy failed: %v", err)
			}

			// Silence ErrEdgeNotFound so that the batch can
			// succeed, but propagate the error via local state.
			if errors.Is(err, ErrEdgeNotFound) {
				edgeNotFound = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case edgeNotFound:
				return ErrEdgeNotFound
			default:
				s.updateEdgeCache(edge, isUpdate1)
				return nil
			}
		},
	}

	err := s.chanScheduler.Execute(ctx, r)

	return from, to, err
}
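
// Illustrative sketch (not part of the upstream file): ErrEdgeNotFound is
// silenced inside the batched Do closure and re-raised from OnCommit, so
// callers can still match on it directly. "policy" is an assumed
// *models.ChannelEdgePolicy decoded from a gossip channel_update.
//
//	from, to, err := store.UpdateEdgePolicy(ctx, policy)
//	switch {
//	case errors.Is(err, ErrEdgeNotFound):
//		// The channel announcement has not been seen yet; the update
//		// may be cached and replayed later.
//	case err != nil:
//		return fmt.Errorf("unable to update edge policy: %w", err)
//	default:
//		log.Debugf("updated policy for edge between %x and %x",
//			from, to)
//	}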

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
	isUpdate1 bool) {

	// If an entry for this channel is found in reject cache, we'll modify
	// the entry with the updated timestamp for the direction that was just
	// written. If the edge doesn't exist, we'll load the cache entry lazily
	// during the next query for this edge.
	if entry, ok := s.rejectCache.get(e.ChannelID); ok {
		if isUpdate1 {
			entry.upd1Time = e.LastUpdate.Unix()
		} else {
			entry.upd2Time = e.LastUpdate.Unix()
		}
		s.rejectCache.insert(e.ChannelID, entry)
	}

	// If an entry for this channel is found in channel cache, we'll modify
	// the entry with the updated policy for the direction that was just
	// written. If the edge doesn't exist, we'll defer loading the info and
	// policies and lazily read from disk during the next query.
	if channel, ok := s.chanCache.get(e.ChannelID); ok {
		if isUpdate1 {
			channel.Policy1 = e
		} else {
			channel.Policy2 = e
		}
		s.chanCache.insert(e.ChannelID, channel)
	}
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
	cb func(chanPoint wire.OutPoint, havePolicy bool,
		otherNode *models.LightningNode) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		return forEachNodeChannel(
			ctx, db, s.cfg, nodeID,
			func(info *models.ChannelEdgeInfo,
				outPolicy *models.ChannelEdgePolicy,
				_ *models.ChannelEdgePolicy) error {

				// Fetch the other node.
				var (
					otherNodePub [33]byte
					node1        = info.NodeKey1Bytes
					node2        = info.NodeKey2Bytes
				)
				switch {
				case bytes.Equal(node1[:], nodePub[:]):
					otherNodePub = node2
				case bytes.Equal(node2[:], nodePub[:]):
					otherNodePub = node1
				default:
					return fmt.Errorf("node not " +
						"participating in this channel")
				}

				_, otherNode, err := getNodeByPubKey(
					ctx, s.cfg.QueryCfg, db, otherNodePub,
				)
				if err != nil {
					return fmt.Errorf("unable to fetch "+
						"other node(%x): %w",
						otherNodePub, err)
				}

				return cb(
					info.ChannelPoint, outPolicy != nil,
					otherNode,
				)
			},
		)
	}, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
	cb func(node *models.LightningNode) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodePaginated(
			ctx, s.cfg.QueryCfg, db,
			ProtocolV1, func(_ context.Context, _ int64,
				node *models.LightningNode) error {

				return cb(node)
			},
		)
	}, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
	cb func(channel *DirectedChannel) error, reset func()) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
	}, reset)
}
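
// Illustrative sketch (not part of the upstream file): a pathfinding-style
// traversal over one node's directed channels. The reset closure runs if the
// underlying read transaction has to be retried, so any state accumulated in
// the callback should be cleared there. This assumes DirectedChannel exposes a
// Capacity field of type btcutil.Amount.
//
//	var total btcutil.Amount
//	err := store.ForEachNodeDirectedChannel(nodePub,
//		func(c *DirectedChannel) error {
//			total += c.Capacity
//			return nil
//		},
//		func() {
//			// Transaction retried: start the sum over.
//			total = 0
//		},
//	)
//	if err != nil {
//		return err
//	}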

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
	cb func(route.Vertex, *lnwire.FeatureVector) error,
	reset func()) error {

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeCacheable(
			ctx, s.cfg.QueryCfg, db,
			func(_ int64, nodePub route.Vertex,
				features *lnwire.FeatureVector) error {

				return cb(nodePub, features)
			},
		)
	}, reset)
	if err != nil {
		return fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  nodePub[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
	}, reset)
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
	endTime time.Time) ([]ChannelEdge, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx = context.TODO()
		// To ensure we don't return duplicate ChannelEdges, we'll use
		// an additional map to keep track of the edges already seen to
		// prevent re-adding it.
		edgesSeen    = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges        []ChannelEdge
		hits         int
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsByPolicyLastUpdateRange(
			ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
				Version:   int16(ProtocolV1),
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return err
		}

		if len(rows) == 0 {
			return nil
		}

		// We'll pre-allocate the slices and maps here with a best
		// effort size in order to avoid unnecessary allocations later
		// on.
		uncachedRows := make(
			[]sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
			len(rows),
		)
		edgesToCache = make(map[uint64]ChannelEdge, len(rows))
		edgesSeen = make(map[uint64]struct{}, len(rows))
		edges = make([]ChannelEdge, 0, len(rows))

		// Separate cached from non-cached channels since we will only
		// batch load the data for the ones we haven't cached yet.
		for _, row := range rows {
			chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)

			// Skip duplicates.
			if _, ok := edgesSeen[chanIDInt]; ok {
				continue
			}
			edgesSeen[chanIDInt] = struct{}{}

			// Check cache first.
			if channel, ok := s.chanCache.get(chanIDInt); ok {
				hits++
				edges = append(edges, channel)
				continue
			}

			// Mark this row as one we need to batch load data for.
			uncachedRows = append(uncachedRows, row)
		}

		// If there are no uncached rows, then we can return early.
		if len(uncachedRows) == 0 {
			return nil
		}

		// Batch load data for all uncached channels.
		newEdges, err := batchBuildChannelEdges(
			ctx, s.cfg, db, uncachedRows,
		)
		if err != nil {
			return fmt.Errorf("unable to batch build channel "+
				"edges: %w", err)
		}

		edges = append(edges, newEdges...)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Insert any edges loaded from disk into the cache.
	for chanid, channel := range edgesToCache {
		s.chanCache.insert(chanid, channel)
	}

	if len(edges) > 0 {
		log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
			float64(hits)*100/float64(len(edges)), hits, len(edges))
	} else {
		log.Debugf("ChanUpdatesInHorizon returned no edges in "+
			"horizon (%s, %s)", startTime, endTime)
	}

	return edges, nil
}

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
	cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
		chans map[uint64]*DirectedChannel) error, reset func()) error {

	type nodeCachedBatchData struct {
		features      map[int64][]int
		addrs         map[int64][]nodeAddress
		chanBatchData *batchChannelData
		chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// pageQueryFunc is used to query the next page of nodes.
		pageQueryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

			return db.ListNodeIDsAndPubKeys(
				ctx, sqlc.ListNodeIDsAndPubKeysParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		// batchDataFunc is then used to batch load the data required
		// for each page of nodes.
		batchDataFunc := func(ctx context.Context,
			nodeIDs []int64) (*nodeCachedBatchData, error) {

			// Batch load node features.
			nodeFeatures, err := batchLoadNodeFeaturesHelper(
				ctx, s.cfg.QueryCfg, db, nodeIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch load "+
					"node features: %w", err)
			}

			// Maybe fetch the node's addresses if requested.
			var nodeAddrs map[int64][]nodeAddress
			if withAddrs {
				nodeAddrs, err = batchLoadNodeAddressesHelper(
					ctx, s.cfg.QueryCfg, db, nodeIDs,
				)
				if err != nil {
					return nil, fmt.Errorf("unable to "+
						"batch load node "+
						"addresses: %w", err)
				}
			}

			// Batch load ALL unique channels for ALL nodes in this
			// page.
			allChannels, err := db.ListChannelsForNodeIDs(
				ctx, sqlc.ListChannelsForNodeIDsParams{
					Version:  int16(ProtocolV1),
					Node1Ids: nodeIDs,
					Node2Ids: nodeIDs,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"fetch channels for nodes: %w", err)
			}

			// Deduplicate channels and collect IDs.
			var (
				allChannelIDs []int64
				allPolicyIDs  []int64
			)
			uniqueChannels := make(
				map[int64]sqlc.ListChannelsForNodeIDsRow,
			)

			for _, channel := range allChannels {
				channelID := channel.GraphChannel.ID

				// Only process each unique channel once.
				_, exists := uniqueChannels[channelID]
				if exists {
					continue
				}

				uniqueChannels[channelID] = channel
				allChannelIDs = append(allChannelIDs, channelID)

				if channel.Policy1ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy1ID.Int64,
					)
				}
				if channel.Policy2ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy2ID.Int64,
					)
				}
			}

			// Batch load channel data for all unique channels.
			channelBatchData, err := batchLoadChannelData(
				ctx, s.cfg.QueryCfg, db, allChannelIDs,
				allPolicyIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"load channel data: %w", err)
			}

			// Create map of node ID to channels that involve this
			// node.
			nodeIDSet := make(map[int64]bool)
			for _, nodeID := range nodeIDs {
				nodeIDSet[nodeID] = true
			}

			nodeChannelMap := make(
				map[int64][]sqlc.ListChannelsForNodeIDsRow,
			)
			for _, channel := range uniqueChannels {
				// Add channel to both nodes if they're in our
				// current page.
				node1 := channel.GraphChannel.NodeID1
				if nodeIDSet[node1] {
					nodeChannelMap[node1] = append(
						nodeChannelMap[node1], channel,
					)
				}
				node2 := channel.GraphChannel.NodeID2
				if nodeIDSet[node2] {
					nodeChannelMap[node2] = append(
						nodeChannelMap[node2], channel,
					)
				}
			}

			return &nodeCachedBatchData{
				features:      nodeFeatures,
				addrs:         nodeAddrs,
				chanBatchData: channelBatchData,
				chanMap:       nodeChannelMap,
			}, nil
		}

		// processItem is used to process each node in the current page.
		processItem := func(ctx context.Context,
			nodeData sqlc.ListNodeIDsAndPubKeysRow,
			batchData *nodeCachedBatchData) error {

			// Build feature vector for this node.
			fv := lnwire.EmptyFeatureVector()
			features, exists := batchData.features[nodeData.ID]
			if exists {
				for _, bit := range features {
					fv.Set(lnwire.FeatureBit(bit))
				}
			}

			var nodePub route.Vertex
			copy(nodePub[:], nodeData.PubKey)

			nodeChannels := batchData.chanMap[nodeData.ID]

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			// Build cached channels map for this node.
			channels := make(map[uint64]*DirectedChannel)
			for _, channelRow := range nodeChannels {
				directedChan, err := buildDirectedChannel(
					s.cfg.ChainHash, nodeData.ID, nodePub,
					channelRow, batchData.chanBatchData, fv,
					toNodeCallback,
				)
				if err != nil {
					return err
				}

				channels[directedChan.ChannelID] = directedChan
			}

			addrs, err := buildNodeAddresses(
				batchData.addrs[nodeData.ID],
			)
			if err != nil {
				return fmt.Errorf("unable to build node "+
					"addresses: %w", err)
			}

			return cb(ctx, nodePub, addrs, channels)
		}

		return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
			ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
			func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1236
                                return node.ID
×
1237
                        },
×
1238
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1239
                                error) {
×
1240

×
1241
                                return node.ID, nil
×
1242
                        },
×
1243
                        batchDataFunc, processItem,
1244
                )
1245
        }, reset)
1246
}
1247

1248
// ForEachChannelCacheable iterates through all the channel edges stored
1249
// within the graph and invokes the passed callback for each edge. The
1250
// callback takes two edges as since this is a directed graph, both the
1251
// in/out edges are visited. If the callback returns an error, then the
1252
// transaction is aborted and the iteration stops early.
1253
//
1254
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1255
// pointer for that particular channel edge routing policy will be
1256
// passed into the callback.
1257
//
1258
// NOTE: this method is like ForEachChannel but fetches only the data
1259
// required for the graph cache.
1260
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1261
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1262
        reset func()) error {
×
1263

×
1264
        ctx := context.TODO()
×
1265

×
1266
        handleChannel := func(_ context.Context,
×
1267
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1268

×
1269
                node1, node2, err := buildNodeVertices(
×
1270
                        row.Node1Pubkey, row.Node2Pubkey,
×
1271
                )
×
1272
                if err != nil {
×
1273
                        return err
×
1274
                }
×
1275

1276
                edge := buildCacheableChannelInfo(
×
1277
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1278
                )
×
1279

×
1280
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1281
                if err != nil {
×
1282
                        return err
×
1283
                }
×
1284

1285
                pol1, pol2, err := buildCachedChanPolicies(
×
1286
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1287
                )
×
1288
                if err != nil {
×
1289
                        return err
×
1290
                }
×
1291

1292
                return cb(edge, pol1, pol2)
×
1293
        }
1294

1295
        extractCursor := func(
×
1296
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1297

×
1298
                return row.ID
×
1299
        }
×
1300

1301
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1302
                //nolint:ll
×
1303
                queryFunc := func(ctx context.Context, lastID int64,
×
1304
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1305
                        error) {
×
1306

×
1307
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1308
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1309
                                        Version: int16(ProtocolV1),
×
1310
                                        ID:      lastID,
×
1311
                                        Limit:   limit,
×
1312
                                },
×
1313
                        )
×
1314
                }
×
1315

1316
                return sqldb.ExecutePaginatedQuery(
×
1317
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1318
                        extractCursor, handleChannel,
×
1319
                )
×
1320
        }, reset)
1321
}

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since this is a directed graph; both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
	}, reset)
}
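
// Illustrative usage sketch (editor's addition, not part of the original
// source): a hypothetical caller that sums the capacity of every known
// channel via ForEachChannel. The variable names and the Capacity field
// access are assumptions made for this example only.
//
//	var total btcutil.Amount
//	err := store.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo,
//		_ *models.ChannelEdgePolicy,
//		_ *models.ChannelEdgePolicy) error {
//
//		total += info.Capacity
//		return nil
//	}, func() {
//		// The reset closure runs if the read transaction is retried,
//		// so any state accumulated in the callback must be cleared.
//		total = 0
//	})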

// FilterChannelRange returns the channel IDs of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
	withTimestamps bool) ([]BlockChannelRange, error) {

	var (
		ctx       = context.TODO()
		startSCID = &lnwire.ShortChannelID{
			BlockHeight: startHeight,
		}
		endSCID = lnwire.ShortChannelID{
			BlockHeight: endHeight,
			TxIndex:     math.MaxUint32 & 0x00ffffff,
			TxPosition:  math.MaxUint16,
		}
		chanIDStart = channelIDToBytes(startSCID.ToUint64())
		chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
	)

	// 1) get all channels where channelID is between start and end chan ID.
	// 2) skip if not public (ie, no channel_proof)
	// 3) collect that channel.
	// 4) if timestamps are wanted, fetch both policies for node 1 and node2
	//    and add those timestamps to the collected channel.
	channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChans, err := db.GetPublicV1ChannelsBySCID(
			ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
				StartScid: chanIDStart,
				EndScid:   chanIDEnd,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch channel range: %w",
				err)
		}

		for _, dbChan := range dbChans {
			cid := lnwire.NewShortChanIDFromInt(
				byteOrder.Uint64(dbChan.Scid),
			)
			chanInfo := NewChannelUpdateInfo(
				cid, time.Time{}, time.Time{},
			)

			if !withTimestamps {
				channelsPerBlock[cid.BlockHeight] = append(
					channelsPerBlock[cid.BlockHeight],
					chanInfo,
				)

				continue
			}

			//nolint:ll
			node1Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID1,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node1 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node1UpdateTimestamp = time.Unix(
					node1Policy.LastUpdate.Int64, 0,
				)
			}

			//nolint:ll
			node2Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID2,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node2 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node2UpdateTimestamp = time.Unix(
					node2Policy.LastUpdate.Int64, 0,
				)
			}

			channelsPerBlock[cid.BlockHeight] = append(
				channelsPerBlock[cid.BlockHeight], chanInfo,
			)
		}

		return nil
	}, func() {
		channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channel range: %w", err)
	}

	if len(channelsPerBlock) == 0 {
		return nil, nil
	}

	// Return the channel ranges in ascending block height order.
	blocks := slices.Collect(maps.Keys(channelsPerBlock))
	slices.Sort(blocks)

	return fn.Map(blocks, func(block uint32) BlockChannelRange {
		return BlockChannelRange{
			Height:   block,
			Channels: channelsPerBlock[block],
		}
	}), nil
}
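
// Illustrative usage sketch (editor's addition, not part of the original
// source): answering a gossip range query by collecting the channels mined
// between two example block heights, including policy update timestamps.
//
//	ranges, err := store.FilterChannelRange(700_000, 710_000, true)
//	if err != nil {
//		return err
//	}
//	for _, blockRange := range ranges {
//		// blockRange.Height is the common block height, and
//		// blockRange.Channels holds the ChannelUpdateInfo entries
//		// with their latest update timestamps.
//		_ = blockRange
//	}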

// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
	pubKey1, pubKey2 [33]byte) error {

	ctx := context.TODO()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	chanIDB := channelIDToBytes(chanID)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		return db.UpsertZombieChannel(
			ctx, sqlc.UpsertZombieChannelParams{
				Version:  int16(ProtocolV1),
				Scid:     chanIDB,
				NodeKey1: pubKey1[:],
				NodeKey2: pubKey2[:],
			},
		)
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to upsert zombie channel "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return nil
}

// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx     = context.TODO()
		chanIDB = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteZombieChannel(
			ctx, sqlc.DeleteZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete zombie channel: %w",
				err)
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrZombieEdgeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d zombie rows, "+
				"expected 1", rows)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to mark edge live "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return err
}

// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
	error) {

	var (
		ctx              = context.TODO()
		isZombie         bool
		pubKey1, pubKey2 route.Vertex
		chanIDB          = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		zombie, err := db.GetZombieChannel(
			ctx, sqlc.GetZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("unable to fetch zombie channel: %w",
				err)
		}

		copy(pubKey1[:], zombie.NodeKey1)
		copy(pubKey2[:], zombie.NodeKey2)
		isZombie = true

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, route.Vertex{}, route.Vertex{},
			fmt.Errorf("%w: %w (chanID=%d)",
				ErrCantCheckIfZombieEdgeStr, err, chanID)
	}

	return isZombie, pubKey1, pubKey2, nil
}

// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
	var (
		ctx        = context.TODO()
		numZombies uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
		if err != nil {
			return fmt.Errorf("unable to count zombie channels: %w",
				err)
		}

		numZombies = uint64(count)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to count zombies: %w", err)
	}

	return numZombies, nil
}
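
// Illustrative usage sketch (editor's addition, not part of the original
// source): the zombie-index lifecycle exposed by the methods above. The
// channel ID and public keys are placeholder values.
//
//	if err := store.MarkEdgeZombie(chanID, pubKey1, pubKey2); err != nil {
//		return err
//	}
//	isZombie, _, _, err := store.IsZombieEdge(chanID)
//	if err != nil {
//		return err
//	}
//	if isZombie {
//		// Resurrect the channel once a fresh update has been seen.
//		if err := store.MarkEdgeLive(chanID); err != nil {
//			return err
//		}
//	}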

// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
	chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// Keep track of which channels we end up finding so that we can
	// correctly return ErrEdgeNotFound if we do not find a channel.
	chanLookup := make(map[uint64]struct{}, len(chanIDs))
	for _, chanID := range chanIDs {
		chanLookup[chanID] = struct{}{}
	}

	var (
		ctx   = context.TODO()
		edges []*models.ChannelEdgeInfo
	)
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// First, collect all channel rows.
		var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			// Deleting the entry from the map indicates that we
			// have found the channel.
			scid := byteOrder.Uint64(row.GraphChannel.Scid)
			delete(chanLookup, scid)

			channelRows = append(channelRows, row)

			return nil
		}

		err := s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
		if err != nil {
			return err
		}

		if len(chanLookup) > 0 {
			return ErrEdgeNotFound
		}

		if len(channelRows) == 0 {
			return nil
		}

		// Batch build all channel edges.
		var chanIDsToDelete []int64
		edges, chanIDsToDelete, err = batchBuildChannelInfo(
			ctx, s.cfg, db, channelRows,
		)
		if err != nil {
			return err
		}

		if markZombie {
			for i, row := range channelRows {
				scid := byteOrder.Uint64(row.GraphChannel.Scid)

				err := handleZombieMarking(
					ctx, db, row, edges[i],
					strictZombiePruning, scid,
				)
				if err != nil {
					return fmt.Errorf("unable to mark "+
						"channel as zombie: %w", err)
				}
			}
		}

		return s.deleteChannels(ctx, db, chanIDsToDelete)
	}, func() {
		edges = nil

		// Re-fill the lookup map.
		for _, chanID := range chanIDs {
			chanLookup[chanID] = struct{}{}
		}
	})
	if err != nil {
		return nil, fmt.Errorf("unable to delete channel edges: %w",
			err)
	}

	for _, chanID := range chanIDs {
		s.rejectCache.remove(chanID)
		s.chanCache.remove(chanID)
	}

	return edges, nil
}

// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
		chanIDB          = channelIDToBytes(chanID)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelBySCIDWithPolicies(
			ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// First check if this edge is perhaps in the zombie
			// index.
			zombie, err := db.GetZombieChannel(
				ctx, sqlc.GetZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if errors.Is(err, sql.ErrNoRows) {
				return ErrEdgeNotFound
			} else if err != nil {
				return fmt.Errorf("unable to check if "+
					"channel is zombie: %w", err)
			}

			// At this point, we know the channel is a zombie, so
			// we'll return an error indicating this, and we will
			// populate the edge info with the public keys of each
			// party as this is the only information we have about
			// it.
			edge = &models.ChannelEdgeInfo{}
			copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
			copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)

			return ErrZombieEdge
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.GraphNode.PubKey, row.GraphNode_2.PubKey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		// If we are returning the ErrZombieEdge, then we also need to
		// return the edge info as the method comment indicates that
		// this will be populated when the edge is a zombie.
		return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}
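
// Illustrative usage sketch (editor's addition, not part of the original
// source): fetching both directions of a channel by its compact channel ID
// and handling the zombie case, which is signalled via ErrZombieEdge while
// the edge info (only the node keys) is still returned.
//
//	info, pol1, pol2, err := store.FetchChannelEdgesByID(chanID)
//	switch {
//	case errors.Is(err, ErrZombieEdge):
//		// Only info.NodeKey1Bytes and info.NodeKey2Bytes are set.
//	case err != nil:
//		return err
//	default:
//		_, _, _ = info, pol1, pol2
//	}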

// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelByOutpointWithPolicies(
			ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
				Outpoint: op.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
	bool, error) {

	ctx := context.TODO()

	var (
		exists          bool
		isZombie        bool
		node1LastUpdate time.Time
		node2LastUpdate time.Time
	)

	// We'll query the cache with the shared lock held to allow multiple
	// readers to access values in the cache concurrently if they exist.
	s.cacheMu.RLock()
	if entry, ok := s.rejectCache.get(chanID); ok {
		s.cacheMu.RUnlock()
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}
	s.cacheMu.RUnlock()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// The item was not found with the shared lock, so we'll acquire the
	// exclusive lock and check the cache again in case another method added
	// the entry to the cache while no lock was held.
	if entry, ok := s.rejectCache.get(chanID); ok {
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}

	chanIDB := channelIDToBytes(chanID)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		channel, err := db.GetChannelBySCID(
			ctx, sqlc.GetChannelBySCIDParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// Check if it is a zombie channel.
			isZombie, err = db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if err != nil {
				return fmt.Errorf("could not check if channel "+
					"is zombie: %w", err)
			}

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		exists = true

		policy1, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID1,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
		}

		policy2, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID2,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, time.Time{}, false, false,
			fmt.Errorf("unable to fetch channel: %w", err)
	}

	s.rejectCache.insert(chanID, rejectCacheEntry{
		upd1Time: node1LastUpdate.Unix(),
		upd2Time: node2LastUpdate.Unix(),
		flags:    packRejectFlags(exists, isZombie),
	})

	return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
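
// Illustrative usage sketch (editor's addition, not part of the original
// source): using HasChannelEdge to decide whether a remote channel update is
// worth processing. msgTimestamp and the freshness heuristic are assumptions
// made for this example only.
//
//	upd1, upd2, exists, isZombie, err := store.HasChannelEdge(chanID)
//	if err != nil {
//		return err
//	}
//	if isZombie || (exists && !upd1.Before(msgTimestamp) &&
//		!upd2.Before(msgTimestamp)) {
//
//		// Nothing new to learn from this announcement.
//		return nil
//	}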

// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
	var (
		ctx       = context.TODO()
		channelID uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.GetSCIDByOutpoint(
			ctx, sqlc.GetSCIDByOutpointParams{
				Outpoint: chanPoint.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel ID: %w",
				err)
		}

		channelID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
	}

	return channelID, nil
}

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
	ctx := context.TODO()

	var isPublic bool
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to check if node is "+
			"public: %w", err)
	}

	return isPublic, nil
}

// FetchChanInfos returns the set of channel edges that correspond to the passed
// channel IDs. If an edge in the query is unknown to the database, it will be
// skipped and the result will contain only those edges that exist at the time
// of the query. This can be used to respond to peer queries that are seeking to
// fill in gaps in their view of the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
	var (
		ctx   = context.TODO()
		edges = make(map[uint64]ChannelEdge)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, collect all channel rows.
		var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			channelRows = append(channelRows, row)
			return nil
		}

		err := s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
		if err != nil {
			return err
		}

		if len(channelRows) == 0 {
			return nil
		}

		// Batch build all channel edges.
		chans, err := batchBuildChannelEdges(
			ctx, s.cfg, db, channelRows,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel edges: %w",
				err)
		}

		for _, c := range chans {
			edges[c.Info.ChannelID] = c
		}

		return err
	}, func() {
		clear(edges)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	res := make([]ChannelEdge, 0, len(edges))
	for _, chanID := range chanIDs {
		edge, ok := edges[chanID]
		if !ok {
			continue
		}

		res = append(res, edge)
	}

	return res, nil
}

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
	db SQLQueries, cb func(ctx context.Context,
		row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
	chanIDs []uint64) error {

	queryWrapper := func(ctx context.Context,
		scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
		error) {

		return db.GetChannelsBySCIDWithPolicies(
			ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
				Version: int16(ProtocolV1),
				Scids:   scids,
			},
		)
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
		cb,
	)
}

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
// IDs that we don't know of and that are not known zombies of the passed set.
// In other words, we perform a set difference of our set of chan IDs and the
// ones passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
// known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
	[]ChannelUpdateInfo, error) {

	var (
		ctx          = context.TODO()
		newChanIDs   []uint64
		knownZombies []ChannelUpdateInfo
		infoLookup   = make(
			map[uint64]ChannelUpdateInfo, len(chansInfo),
		)
	)

	// We first build a lookup map of the channel ID's to the
	// ChannelUpdateInfo. This allows us to quickly delete channels that we
	// already know about.
	for _, chanInfo := range chansInfo {
		infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// The call-back function deletes known channels from
		// infoLookup, so that we can later check which channels are
		// zombies by only looking at the remaining channels in the set.
		cb := func(ctx context.Context,
			channel sqlc.GraphChannel) error {

			delete(infoLookup, byteOrder.Uint64(channel.Scid))

			return nil
		}

		err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
		if err != nil {
			return fmt.Errorf("unable to iterate through "+
				"channels: %w", err)
		}

		// We want to ensure that we deal with the channels in the
		// same order that they were passed in, so we iterate over the
		// original chansInfo slice and then check if that channel is
		// still in the infoLookup map.
		for _, chanInfo := range chansInfo {
			channelID := chanInfo.ShortChannelID.ToUint64()
			if _, ok := infoLookup[channelID]; !ok {
				continue
			}

			isZombie, err := db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    channelIDToBytes(channelID),
					Version: int16(ProtocolV1),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch zombie "+
					"channel: %w", err)
			}

			if isZombie {
				knownZombies = append(knownZombies, chanInfo)

				continue
			}

			newChanIDs = append(newChanIDs, channelID)
		}

		return nil
	}, func() {
		newChanIDs = nil
		knownZombies = nil
		// Rebuild the infoLookup map in case of a rollback.
		for _, chanInfo := range chansInfo {
			scid := chanInfo.ShortChannelID.ToUint64()
			infoLookup[scid] = chanInfo
		}
	})
	if err != nil {
		return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	return newChanIDs, knownZombies, nil
}
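
// Illustrative usage sketch (editor's addition, not part of the original
// source): computing the set difference between a peer's advertised channels
// and our own view. chansInfo is assumed to have been built from the peer's
// reply_channel_range messages.
//
//	newSCIDs, zombies, err := store.FilterKnownChanIDs(chansInfo)
//	if err != nil {
//		return err
//	}
//	// newSCIDs are channels to query the peer about; zombies are entries
//	// we have already deemed dead and can skip.
//	_, _ = newSCIDs, zombies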

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
	cb func(ctx context.Context, channel sqlc.GraphChannel) error,
	chansInfo []ChannelUpdateInfo) error {

	queryWrapper := func(ctx context.Context,
		scids [][]byte) ([]sqlc.GraphChannel, error) {

		return db.GetChannelsBySCIDs(
			ctx, sqlc.GetChannelsBySCIDsParams{
				Version: int16(ProtocolV1),
				Scids:   scids,
			},
		)
	}

	chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
		channelID := chanInfo.ShortChannelID.ToUint64()

		return channelIDToBytes(channelID)
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
		cb,
	)
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This ensures
// that we only maintain a graph of reachable nodes. In the event that a pruned
// node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
	var ctx = context.TODO()

	var prunedNodes []route.Vertex
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		var err error
		prunedNodes, err = s.pruneGraphNodes(ctx, db)

		return err
	}, func() {
		prunedNodes = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to prune nodes: %w", err)
	}

	return prunedNodes, nil
}
2317

2318
// PruneGraph prunes newly closed channels from the channel graph in response
2319
// to a new block being solved on the network. Any transactions which spend the
2320
// funding output of any known channels within he graph will be deleted.
2321
// Additionally, the "prune tip", or the last block which has been used to
2322
// prune the graph is stored so callers can ensure the graph is fully in sync
2323
// with the current UTXO state. A slice of channels that have been closed by
2324
// the target block along with any pruned nodes are returned if the function
2325
// succeeds without error.
2326
//
2327
// NOTE: part of the V1Store interface.
2328
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2329
        blockHash *chainhash.Hash, blockHeight uint32) (
2330
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2331

×
2332
        ctx := context.TODO()
×
2333

×
2334
        s.cacheMu.Lock()
×
2335
        defer s.cacheMu.Unlock()
×
2336

×
2337
        var (
×
2338
                closedChans []*models.ChannelEdgeInfo
×
2339
                prunedNodes []route.Vertex
×
2340
        )
×
2341
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2342
                // First, collect all channel rows that need to be pruned.
×
2343
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2344
                channelCallback := func(ctx context.Context,
×
2345
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2346

×
2347
                        channelRows = append(channelRows, row)
×
2348

×
2349
                        return nil
×
2350
                }
×
2351

2352
                err := s.forEachChanInOutpoints(
×
2353
                        ctx, db, spentOutputs, channelCallback,
×
2354
                )
×
2355
                if err != nil {
×
2356
                        return fmt.Errorf("unable to fetch channels by "+
×
2357
                                "outpoints: %w", err)
×
2358
                }
×
2359

2360
                if len(channelRows) == 0 {
×
2361
                        // There are no channels to prune. So we can exit early
×
2362
                        // after updating the prune log.
×
2363
                        err = db.UpsertPruneLogEntry(
×
2364
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2365
                                        BlockHash:   blockHash[:],
×
2366
                                        BlockHeight: int64(blockHeight),
×
2367
                                },
×
2368
                        )
×
2369
                        if err != nil {
×
2370
                                return fmt.Errorf("unable to insert prune log "+
×
2371
                                        "entry: %w", err)
×
2372
                        }
×
2373

2374
                        return nil
×
2375
                }
2376

2377
                // Batch build all channel edges for pruning.
2378
                var chansToDelete []int64
×
2379
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2380
                        ctx, s.cfg, db, channelRows,
×
2381
                )
×
2382
                if err != nil {
×
2383
                        return err
×
2384
                }
×
2385

2386
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2387
                if err != nil {
×
2388
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2389
                }
×
2390

2391
                err = db.UpsertPruneLogEntry(
×
2392
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2393
                                BlockHash:   blockHash[:],
×
2394
                                BlockHeight: int64(blockHeight),
×
2395
                        },
×
2396
                )
×
2397
                if err != nil {
×
2398
                        return fmt.Errorf("unable to insert prune log "+
×
2399
                                "entry: %w", err)
×
2400
                }
×
2401

2402
                // Now that we've pruned some channels, we'll also prune any
2403
                // nodes that no longer have any channels.
2404
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2405
                if err != nil {
×
2406
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2407
                                err)
×
2408
                }
×
2409

2410
                return nil
×
2411
        }, func() {
×
2412
                prunedNodes = nil
×
2413
                closedChans = nil
×
2414
        })
×
2415
        if err != nil {
×
2416
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2417
        }
×
2418

2419
        for _, channel := range closedChans {
×
2420
                s.rejectCache.remove(channel.ChannelID)
×
2421
                s.chanCache.remove(channel.ChannelID)
×
2422
        }
×
2423

2424
        return closedChans, prunedNodes, nil
×
2425
}
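
// The following is an illustrative sketch (not part of this file) of how a
// chain-watching caller might drive PruneGraph for each newly connected
// block and then confirm that the prune tip advanced. The helper name and
// its arguments are hypothetical.
func examplePruneOnBlockConnected(s *SQLStore, hash *chainhash.Hash,
        height uint32, spentOutputs []*wire.OutPoint) error {

        // Remove any channels whose funding outputs were spent in this block.
        closed, pruned, err := s.PruneGraph(spentOutputs, hash, height)
        if err != nil {
                return err
        }

        // The prune log is updated even when no channels were closed, so the
        // prune tip should now reference the block we just processed.
        tipHash, tipHeight, err := s.PruneTip()
        if err != nil {
                return err
        }
        if tipHeight != height || *tipHash != *hash {
                return fmt.Errorf("prune tip not advanced to height %d",
                        height)
        }

        fmt.Printf("pruned %d channels and %d nodes\n", len(closed),
                len(pruned))

        return nil
}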
2426

2427
// forEachChanInOutpoints is a helper function that executes a paginated
2428
// query to fetch channels by their outpoints and applies the given call-back
2429
// to each.
2430
//
2431
// NOTE: this fetches channels for all protocol versions.
2432
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2433
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2434
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2435

×
2436
        // Create a wrapper that uses the transaction's db instance to execute
×
2437
        // the query.
×
2438
        queryWrapper := func(ctx context.Context,
×
2439
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2440
                error) {
×
2441

×
2442
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2443
        }
×
2444

2445
        // Define the conversion function from Outpoint to string.
2446
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2447
                return outpoint.String()
×
2448
        }
×
2449

2450
        return sqldb.ExecuteBatchQuery(
×
2451
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2452
                queryWrapper, cb,
×
2453
        )
×
2454
}
2455

2456
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2457
        dbIDs []int64) error {
×
2458

×
2459
        // Create a wrapper that uses the transaction's db instance to execute
×
2460
        // the query.
×
2461
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2462
                return nil, db.DeleteChannels(ctx, ids)
×
2463
        }
×
2464

2465
        idConverter := func(id int64) int64 {
×
2466
                return id
×
2467
        }
×
2468

2469
        return sqldb.ExecuteBatchQuery(
×
2470
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2471
                queryWrapper, func(ctx context.Context, _ any) error {
×
2472
                        return nil
×
2473
                },
×
2474
        )
2475
}
2476

2477
// ChannelView returns the verifiable edge information for each active channel
2478
// within the known channel graph. The set of UTXOs (along with their scripts)
2479
// returned are the ones that need to be watched on chain to detect channel
2480
// closes on the resident blockchain.
2481
//
2482
// NOTE: part of the V1Store interface.
2483
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2484
        var (
×
2485
                ctx        = context.TODO()
×
2486
                edgePoints []EdgePoint
×
2487
        )
×
2488

×
2489
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2490
                handleChannel := func(_ context.Context,
×
2491
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2492

×
2493
                        pkScript, err := genMultiSigP2WSH(
×
2494
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2495
                        )
×
2496
                        if err != nil {
×
2497
                                return err
×
2498
                        }
×
2499

2500
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2501
                        if err != nil {
×
2502
                                return err
×
2503
                        }
×
2504

2505
                        edgePoints = append(edgePoints, EdgePoint{
×
2506
                                FundingPkScript: pkScript,
×
2507
                                OutPoint:        *op,
×
2508
                        })
×
2509

×
2510
                        return nil
×
2511
                }
2512

2513
                queryFunc := func(ctx context.Context, lastID int64,
×
2514
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2515

×
2516
                        return db.ListChannelsPaginated(
×
2517
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2518
                                        Version: int16(ProtocolV1),
×
2519
                                        ID:      lastID,
×
2520
                                        Limit:   limit,
×
2521
                                },
×
2522
                        )
×
2523
                }
×
2524

2525
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2526
                        return row.ID
×
2527
                }
×
2528

2529
                return sqldb.ExecutePaginatedQuery(
×
2530
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2531
                        extractCursor, handleChannel,
×
2532
                )
×
2533
        }, func() {
×
2534
                edgePoints = nil
×
2535
        })
×
2536
        if err != nil {
×
2537
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2538
        }
×
2539

2540
        return edgePoints, nil
×
2541
}
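
// Illustrative sketch only: a caller such as a chain filterer could use
// ChannelView to register every channel funding outpoint (and its pk script)
// it needs to watch for spends. The exampleWatch callback is a hypothetical
// stand-in for whatever chain-notification API the caller uses.
func exampleWatchChannelPoints(s *SQLStore,
        exampleWatch func(op wire.OutPoint, pkScript []byte) error) error {

        edgePoints, err := s.ChannelView()
        if err != nil {
                return err
        }

        for _, point := range edgePoints {
                err := exampleWatch(point.OutPoint, point.FundingPkScript)
                if err != nil {
                        return err
                }
        }

        return nil
}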
2542

2543
// PruneTip returns the block height and hash of the latest block that has been
2544
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2545
// to tell if the graph is currently in sync with the current best known UTXO
2546
// state.
2547
//
2548
// NOTE: part of the V1Store interface.
2549
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2550
        var (
×
2551
                ctx       = context.TODO()
×
2552
                tipHash   chainhash.Hash
×
2553
                tipHeight uint32
×
2554
        )
×
2555
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2556
                pruneTip, err := db.GetPruneTip(ctx)
×
2557
                if errors.Is(err, sql.ErrNoRows) {
×
2558
                        return ErrGraphNeverPruned
×
2559
                } else if err != nil {
×
2560
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2561
                }
×
2562

2563
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2564
                tipHeight = uint32(pruneTip.BlockHeight)
×
2565

×
2566
                return nil
×
2567
        }, sqldb.NoOpReset)
2568
        if err != nil {
×
2569
                return nil, 0, err
×
2570
        }
×
2571

2572
        return &tipHash, tipHeight, nil
×
2573
}
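
// Illustrative sketch only: callers that need a starting height for chain
// rescans can treat ErrGraphNeverPruned as "start from the beginning" rather
// than as a hard failure. The zero fallback height is a hypothetical example
// value.
func examplePruneStartHeight(s *SQLStore) (uint32, error) {
        _, tipHeight, err := s.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                // No prune log entry exists yet, so scan from the first block
                // the caller cares about.
                return 0, nil

        case err != nil:
                return 0, err
        }

        return tipHeight, nil
}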
2574

2575
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2576
//
2577
// NOTE: this prunes nodes across protocol versions. It will never prune the
2578
// source nodes.
2579
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2580
        db SQLQueries) ([]route.Vertex, error) {
×
2581

×
2582
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2583
        if err != nil {
×
2584
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2585
                        "nodes: %w", err)
×
2586
        }
×
2587

2588
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2589
        for i, nodeKey := range nodeKeys {
×
2590
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2591
                if err != nil {
×
2592
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2593
                                "from bytes: %w", err)
×
2594
                }
×
2595

2596
                prunedNodes[i] = pub
×
2597
        }
2598

2599
        return prunedNodes, nil
×
2600
}
2601

2602
// DisconnectBlockAtHeight is used to indicate that the block specified
2603
// by the passed height has been disconnected from the main chain. This
2604
// will "rewind" the graph back to the height below, deleting channels
2605
// that are no longer confirmed from the graph. The prune log will be
2606
// set to the last prune height valid for the remaining chain.
2607
// Channels that were removed from the graph resulting from the
2608
// disconnected block are returned.
2609
//
2610
// NOTE: part of the V1Store interface.
2611
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2612
        []*models.ChannelEdgeInfo, error) {
×
2613

×
2614
        ctx := context.TODO()
×
2615

×
2616
        var (
×
2617
                // Every channel having a ShortChannelID starting at 'height'
×
2618
                // will no longer be confirmed.
×
2619
                startShortChanID = lnwire.ShortChannelID{
×
2620
                        BlockHeight: height,
×
2621
                }
×
2622

×
2623
                // Delete everything after this height from the db up until the
×
2624
                // SCID alias range.
×
2625
                endShortChanID = aliasmgr.StartingAlias
×
2626

×
2627
                removedChans []*models.ChannelEdgeInfo
×
2628

×
2629
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2630
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2631
        )
×
2632

×
2633
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2634
                rows, err := db.GetChannelsBySCIDRange(
×
2635
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2636
                                StartScid: chanIDStart,
×
2637
                                EndScid:   chanIDEnd,
×
2638
                        },
×
2639
                )
×
2640
                if err != nil {
×
2641
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2642
                }
×
2643

2644
                if len(rows) == 0 {
×
2645
                        // No channels to disconnect, but still clean up prune
×
2646
                        // log.
×
2647
                        return db.DeletePruneLogEntriesInRange(
×
2648
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2649
                                        StartHeight: int64(height),
×
2650
                                        EndHeight: int64(
×
2651
                                                endShortChanID.BlockHeight,
×
2652
                                        ),
×
2653
                                },
×
2654
                        )
×
2655
                }
×
2656

2657
                // Batch build all channel edges for disconnection.
2658
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2659
                        ctx, s.cfg, db, rows,
×
2660
                )
×
2661
                if err != nil {
×
2662
                        return err
×
2663
                }
×
2664

2665
                removedChans = channelEdges
×
2666

×
2667
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2668
                if err != nil {
×
2669
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2670
                }
×
2671

2672
                return db.DeletePruneLogEntriesInRange(
×
2673
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2674
                                StartHeight: int64(height),
×
2675
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2676
                        },
×
2677
                )
×
2678
        }, func() {
×
2679
                removedChans = nil
×
2680
        })
×
2681
        if err != nil {
×
2682
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2683
                        "height: %w", err)
×
2684
        }
×
2685

2686
        for _, channel := range removedChans {
×
2687
                s.rejectCache.remove(channel.ChannelID)
×
2688
                s.chanCache.remove(channel.ChannelID)
×
2689
        }
×
2690

2691
        return removedChans, nil
×
2692
}
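
// Illustrative sketch only: on a reorg, the graph can be rewound to the
// common fork point with a single call, since DisconnectBlockAtHeight
// removes every channel confirmed at or above the given height. The
// forkHeight parameter is hypothetical.
func exampleRewindToFork(s *SQLStore, forkHeight uint32) error {
        // Channels confirmed above the fork point are no longer part of the
        // best chain, so drop them from the graph.
        removed, err := s.DisconnectBlockAtHeight(forkHeight + 1)
        if err != nil {
                return err
        }

        fmt.Printf("reorg: removed %d channels above height %d\n",
                len(removed), forkHeight)

        return nil
}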
2693

2694
// AddEdgeProof sets the proof of an existing edge in the graph database.
2695
//
2696
// NOTE: part of the V1Store interface.
2697
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2698
        proof *models.ChannelAuthProof) error {
×
2699

×
2700
        var (
×
2701
                ctx       = context.TODO()
×
2702
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2703
        )
×
2704

×
2705
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2706
                res, err := db.AddV1ChannelProof(
×
2707
                        ctx, sqlc.AddV1ChannelProofParams{
×
2708
                                Scid:              scidBytes,
×
2709
                                Node1Signature:    proof.NodeSig1Bytes,
×
2710
                                Node2Signature:    proof.NodeSig2Bytes,
×
2711
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2712
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2713
                        },
×
2714
                )
×
2715
                if err != nil {
×
2716
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2717
                }
×
2718

2719
                n, err := res.RowsAffected()
×
2720
                if err != nil {
×
2721
                        return err
×
2722
                }
×
2723

2724
                if n == 0 {
×
2725
                        return fmt.Errorf("no rows affected when adding edge "+
×
2726
                                "proof for SCID %v", scid)
×
2727
                } else if n > 1 {
×
2728
                        return fmt.Errorf("multiple rows affected when adding "+
×
2729
                                "edge proof for SCID %v: %d rows affected",
×
2730
                                scid, n)
×
2731
                }
×
2732

2733
                return nil
×
2734
        }, sqldb.NoOpReset)
2735
        if err != nil {
×
2736
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2737
        }
×
2738

2739
        return nil
×
2740
}
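
// Illustrative sketch only: once the four announcement signatures for a
// previously unproven channel become available, they can be attached to the
// stored edge like this. The signature byte slices are hypothetical
// placeholders.
func exampleAttachProof(s *SQLStore, scid lnwire.ShortChannelID,
        nodeSig1, nodeSig2, btcSig1, btcSig2 []byte) error {

        proof := &models.ChannelAuthProof{
                NodeSig1Bytes:    nodeSig1,
                NodeSig2Bytes:    nodeSig2,
                BitcoinSig1Bytes: btcSig1,
                BitcoinSig2Bytes: btcSig2,
        }

        // AddEdgeProof fails if the channel is unknown, so the edge must
        // already have been inserted.
        return s.AddEdgeProof(scid, proof)
}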
2741

2742
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2743
// that we can ignore channel announcements that we know to be closed without
2744
// having to validate them and fetch a block.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2748
        var (
×
2749
                ctx     = context.TODO()
×
2750
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2751
        )
×
2752

×
2753
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2754
                return db.InsertClosedChannel(ctx, chanIDB)
×
2755
        }, sqldb.NoOpReset)
×
2756
}
2757

2758
// IsClosedScid checks whether a channel identified by the passed in scid is
2759
// closed. This helps avoid having to perform expensive validation checks.
2760
//
2761
// NOTE: part of the V1Store interface.
2762
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2763
        var (
×
2764
                ctx      = context.TODO()
×
2765
                isClosed bool
×
2766
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2767
        )
×
2768
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2769
                var err error
×
2770
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2771
                if err != nil {
×
2772
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2773
                                err)
×
2774
                }
×
2775

2776
                return nil
×
2777
        }, sqldb.NoOpReset)
2778
        if err != nil {
×
2779
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2780
                        err)
×
2781
        }
×
2782

2783
        return isClosed, nil
×
2784
}
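
// Illustrative sketch only: a gossiper can consult the closed-SCID index to
// cheaply drop announcements for channels it has already marked closed via
// PutClosedScid, avoiding proof validation and a block fetch. The function
// name is hypothetical.
func exampleShouldProcessAnnouncement(s *SQLStore,
        scid lnwire.ShortChannelID) (bool, error) {

        closed, err := s.IsClosedScid(scid)
        if err != nil {
                return false, err
        }

        // Only process announcements for channels not known to be closed.
        return !closed, nil
}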
2785

2786
// GraphSession will provide the call-back with access to a NodeTraverser
2787
// instance which can be used to perform queries against the channel graph.
2788
//
2789
// NOTE: part of the V1Store interface.
2790
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2791
        reset func()) error {
×
2792

×
2793
        var ctx = context.TODO()
×
2794

×
2795
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2796
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2797
        }, reset)
×
2798
}
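
// Illustrative sketch only: GraphSession hands the callback a consistent,
// read-only view of the graph, which is how a pathfinding routine can expand
// a node's channels without interleaving concurrent writes. The function
// name, source parameter and reset behaviour are hypothetical.
func exampleCountOutgoingChannels(s *SQLStore, source route.Vertex) (int,
        error) {

        var numChannels int
        err := s.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        source, func(_ *DirectedChannel) error {
                                numChannels++
                                return nil
                        }, func() {},
                )
        }, func() {
                // Reset any partial state if the underlying transaction is
                // retried.
                numChannels = 0
        })

        return numChannels, err
}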
2799

2800
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2801
// read only transaction for a consistent view of the graph.
2802
type sqlNodeTraverser struct {
2803
        db    SQLQueries
2804
        chain chainhash.Hash
2805
}
2806

2807
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2808
// NodeTraverser interface.
2809
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2810

2811
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2812
func newSQLNodeTraverser(db SQLQueries,
2813
        chain chainhash.Hash) *sqlNodeTraverser {
×
2814

×
2815
        return &sqlNodeTraverser{
×
2816
                db:    db,
×
2817
                chain: chain,
×
2818
        }
×
2819
}
×
2820

2821
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2822
// node.
2823
//
2824
// NOTE: Part of the NodeTraverser interface.
2825
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2826
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2827

×
2828
        ctx := context.TODO()
×
2829

×
2830
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2831
}
×
2832

2833
// FetchNodeFeatures returns the features of the given node. If the node is
2834
// unknown, assume no additional features are supported.
2835
//
2836
// NOTE: Part of the NodeTraverser interface.
2837
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2838
        *lnwire.FeatureVector, error) {
×
2839

×
2840
        ctx := context.TODO()
×
2841

×
2842
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2843
}
×
2844

2845
// forEachNodeDirectedChannel iterates through all channels of a given
2846
// node, executing the passed callback on the directed edge representing the
2847
// channel and its incoming policy. If the node is not found, no error is
2848
// returned.
2849
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2850
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2851

×
2852
        toNodeCallback := func() route.Vertex {
×
2853
                return nodePub
×
2854
        }
×
2855

2856
        dbID, err := db.GetNodeIDByPubKey(
×
2857
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2858
                        Version: int16(ProtocolV1),
×
2859
                        PubKey:  nodePub[:],
×
2860
                },
×
2861
        )
×
2862
        if errors.Is(err, sql.ErrNoRows) {
×
2863
                return nil
×
2864
        } else if err != nil {
×
2865
                return fmt.Errorf("unable to fetch node: %w", err)
×
2866
        }
×
2867

2868
        rows, err := db.ListChannelsByNodeID(
×
2869
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2870
                        Version: int16(ProtocolV1),
×
2871
                        NodeID1: dbID,
×
2872
                },
×
2873
        )
×
2874
        if err != nil {
×
2875
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2876
        }
×
2877

2878
        // Exit early if there are no channels for this node so we don't
2879
        // do the unnecessary feature fetching.
2880
        if len(rows) == 0 {
×
2881
                return nil
×
2882
        }
×
2883

2884
        features, err := getNodeFeatures(ctx, db, dbID)
×
2885
        if err != nil {
×
2886
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2887
        }
×
2888

2889
        for _, row := range rows {
×
2890
                node1, node2, err := buildNodeVertices(
×
2891
                        row.Node1Pubkey, row.Node2Pubkey,
×
2892
                )
×
2893
                if err != nil {
×
2894
                        return fmt.Errorf("unable to build node vertices: %w",
×
2895
                                err)
×
2896
                }
×
2897

2898
                edge := buildCacheableChannelInfo(
×
2899
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2900
                        node1, node2,
×
2901
                )
×
2902

×
2903
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2904
                if err != nil {
×
2905
                        return err
×
2906
                }
×
2907

2908
                p1, p2, err := buildCachedChanPolicies(
×
2909
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2910
                )
×
2911
                if err != nil {
×
2912
                        return err
×
2913
                }
×
2914

2915
                // Determine the outgoing and incoming policy for this
2916
                // channel and node combo.
2917
                outPolicy, inPolicy := p1, p2
×
2918
                if p1 != nil && node2 == nodePub {
×
2919
                        outPolicy, inPolicy = p2, p1
×
2920
                } else if p2 != nil && node1 != nodePub {
×
2921
                        outPolicy, inPolicy = p2, p1
×
2922
                }
×
2923

2924
                var cachedInPolicy *models.CachedEdgePolicy
×
2925
                if inPolicy != nil {
×
2926
                        cachedInPolicy = inPolicy
×
2927
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2928
                        cachedInPolicy.ToNodeFeatures = features
×
2929
                }
×
2930

2931
                directedChannel := &DirectedChannel{
×
2932
                        ChannelID:    edge.ChannelID,
×
2933
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2934
                        OtherNode:    edge.NodeKey2Bytes,
×
2935
                        Capacity:     edge.Capacity,
×
2936
                        OutPolicySet: outPolicy != nil,
×
2937
                        InPolicy:     cachedInPolicy,
×
2938
                }
×
2939
                if outPolicy != nil {
×
2940
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2941
                                directedChannel.InboundFee = fee
×
2942
                        })
×
2943
                }
2944

2945
                if nodePub == edge.NodeKey2Bytes {
×
2946
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2947
                }
×
2948

2949
                if err := cb(directedChannel); err != nil {
×
2950
                        return err
×
2951
                }
×
2952
        }
2953

2954
        return nil
×
2955
}
2956

2957
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2958
// and executes the provided callback for each node. It does so via pagination
2959
// along with batch loading of the node feature bits.
2960
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2961
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
2962
                features *lnwire.FeatureVector) error) error {
×
2963

×
2964
        handleNode := func(_ context.Context,
×
2965
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
2966
                featureBits map[int64][]int) error {
×
2967

×
2968
                fv := lnwire.EmptyFeatureVector()
×
2969
                if features, exists := featureBits[dbNode.ID]; exists {
×
2970
                        for _, bit := range features {
×
2971
                                fv.Set(lnwire.FeatureBit(bit))
×
2972
                        }
×
2973
                }
2974

2975
                var pub route.Vertex
×
2976
                copy(pub[:], dbNode.PubKey)
×
2977

×
2978
                return processNode(dbNode.ID, pub, fv)
×
2979
        }
2980

2981
        queryFunc := func(ctx context.Context, lastID int64,
×
2982
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
2983

×
2984
                return db.ListNodeIDsAndPubKeys(
×
2985
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2986
                                Version: int16(ProtocolV1),
×
2987
                                ID:      lastID,
×
2988
                                Limit:   limit,
×
2989
                        },
×
2990
                )
×
2991
        }
×
2992

2993
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
2994
                return row.ID
×
2995
        }
×
2996

2997
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
2998
                return node.ID, nil
×
2999
        }
×
3000

3001
        batchQueryFunc := func(ctx context.Context,
×
3002
                nodeIDs []int64) (map[int64][]int, error) {
×
3003

×
3004
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3005
        }
×
3006

3007
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3008
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3009
                batchQueryFunc, handleNode,
×
3010
        )
×
3011
}
3012

3013
// forEachNodeChannel iterates through all channels of a node, executing
3014
// the passed callback on each. The call-back is provided with the channel's
3015
// edge information, the outgoing policy and the incoming policy for the
3016
// channel and node combo.
3017
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3018
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3019
                *models.ChannelEdgePolicy,
3020
                *models.ChannelEdgePolicy) error) error {
×
3021

×
3022
        // Get all the V1 channels for this node.
×
3023
        rows, err := db.ListChannelsByNodeID(
×
3024
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3025
                        Version: int16(ProtocolV1),
×
3026
                        NodeID1: id,
×
3027
                },
×
3028
        )
×
3029
        if err != nil {
×
3030
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3031
        }
×
3032

3033
        // Collect all the channel and policy IDs.
3034
        var (
×
3035
                chanIDs   = make([]int64, 0, len(rows))
×
3036
                policyIDs = make([]int64, 0, 2*len(rows))
×
3037
        )
×
3038
        for _, row := range rows {
×
3039
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3040

×
3041
                if row.Policy1ID.Valid {
×
3042
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3043
                }
×
3044
                if row.Policy2ID.Valid {
×
3045
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3046
                }
×
3047
        }
3048

3049
        batchData, err := batchLoadChannelData(
×
3050
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3051
        )
×
3052
        if err != nil {
×
3053
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3054
        }
×
3055

3056
        // Call the call-back for each channel and its known policies.
3057
        for _, row := range rows {
×
3058
                node1, node2, err := buildNodeVertices(
×
3059
                        row.Node1Pubkey, row.Node2Pubkey,
×
3060
                )
×
3061
                if err != nil {
×
3062
                        return fmt.Errorf("unable to build node vertices: %w",
×
3063
                                err)
×
3064
                }
×
3065

3066
                edge, err := buildEdgeInfoWithBatchData(
×
3067
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3068
                        batchData,
×
3069
                )
×
3070
                if err != nil {
×
3071
                        return fmt.Errorf("unable to build channel info: %w",
×
3072
                                err)
×
3073
                }
×
3074

3075
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3076
                if err != nil {
×
3077
                        return fmt.Errorf("unable to extract channel "+
×
3078
                                "policies: %w", err)
×
3079
                }
×
3080

3081
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3082
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3083
                )
×
3084
                if err != nil {
×
3085
                        return fmt.Errorf("unable to build channel "+
×
3086
                                "policies: %w", err)
×
3087
                }
×
3088

3089
                // Determine the outgoing and incoming policy for this
3090
                // channel and node combo.
3091
                p1ToNode := row.GraphChannel.NodeID2
×
3092
                p2ToNode := row.GraphChannel.NodeID1
×
3093
                outPolicy, inPolicy := p1, p2
×
3094
                if (p1 != nil && p1ToNode == id) ||
×
3095
                        (p2 != nil && p2ToNode != id) {
×
3096

×
3097
                        outPolicy, inPolicy = p2, p1
×
3098
                }
×
3099

3100
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3101
                        return err
×
3102
                }
×
3103
        }
3104

3105
        return nil
×
3106
}
3107

3108
// updateChanEdgePolicy upserts the channel policy info we have stored for
3109
// a channel we already know of.
3110
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3111
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3112
        error) {
×
3113

×
3114
        var (
×
3115
                node1Pub, node2Pub route.Vertex
×
3116
                isNode1            bool
×
3117
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3118
        )
×
3119

×
3120
        // Check that this edge policy refers to a channel that we already
×
3121
        // know of. We do this explicitly so that we can return the appropriate
×
3122
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3123
        // abort the transaction which would abort the entire batch.
×
3124
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3125
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3126
                        Scid:    chanIDB,
×
3127
                        Version: int16(ProtocolV1),
×
3128
                },
×
3129
        )
×
3130
        if errors.Is(err, sql.ErrNoRows) {
×
3131
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3132
        } else if err != nil {
×
3133
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3134
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3135
        }
×
3136

3137
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3138
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3139

×
3140
        // Figure out which node this edge is from.
×
3141
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3142
        nodeID := dbChan.NodeID1
×
3143
        if !isNode1 {
×
3144
                nodeID = dbChan.NodeID2
×
3145
        }
×
3146

3147
        var (
×
3148
                inboundBase sql.NullInt64
×
3149
                inboundRate sql.NullInt64
×
3150
        )
×
3151
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3152
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3153
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3154
        })
×
3155

3156
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3157
                Version:     int16(ProtocolV1),
×
3158
                ChannelID:   dbChan.ID,
×
3159
                NodeID:      nodeID,
×
3160
                Timelock:    int32(edge.TimeLockDelta),
×
3161
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3162
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3163
                MinHtlcMsat: int64(edge.MinHTLC),
×
3164
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3165
                Disabled: sql.NullBool{
×
3166
                        Valid: true,
×
3167
                        Bool:  edge.IsDisabled(),
×
3168
                },
×
3169
                MaxHtlcMsat: sql.NullInt64{
×
3170
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3171
                        Int64: int64(edge.MaxHTLC),
×
3172
                },
×
3173
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3174
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3175
                InboundBaseFeeMsat:      inboundBase,
×
3176
                InboundFeeRateMilliMsat: inboundRate,
×
3177
                Signature:               edge.SigBytes,
×
3178
        })
×
3179
        if err != nil {
×
3180
                return node1Pub, node2Pub, isNode1,
×
3181
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3182
        }
×
3183

3184
        // Convert the flat extra opaque data into a map of TLV types to
3185
        // values.
3186
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3187
        if err != nil {
×
3188
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3189
                        "marshal extra opaque data: %w", err)
×
3190
        }
×
3191

3192
        // Update the channel policy's extra signed fields.
3193
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3194
        if err != nil {
×
3195
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3196
                        "policy extra TLVs: %w", err)
×
3197
        }
×
3198

3199
        return node1Pub, node2Pub, isNode1, nil
×
3200
}
3201

3202
// getNodeByPubKey attempts to look up a target node by its public key.
3203
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3204
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3205

×
3206
        dbNode, err := db.GetNodeByPubKey(
×
3207
                ctx, sqlc.GetNodeByPubKeyParams{
×
3208
                        Version: int16(ProtocolV1),
×
3209
                        PubKey:  pubKey[:],
×
3210
                },
×
3211
        )
×
3212
        if errors.Is(err, sql.ErrNoRows) {
×
3213
                return 0, nil, ErrGraphNodeNotFound
×
3214
        } else if err != nil {
×
3215
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3216
        }
×
3217

3218
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3219
        if err != nil {
×
3220
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3221
        }
×
3222

3223
        return dbNode.ID, node, nil
×
3224
}
3225

3226
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3227
// provided parameters.
3228
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3229
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3230

×
3231
        return &models.CachedEdgeInfo{
×
3232
                ChannelID:     byteOrder.Uint64(scid),
×
3233
                NodeKey1Bytes: node1Pub,
×
3234
                NodeKey2Bytes: node2Pub,
×
3235
                Capacity:      btcutil.Amount(capacity),
×
3236
        }
×
3237
}
×
3238

3239
// buildNode constructs a LightningNode instance from the given database node
3240
// record. The node's features, addresses and extra signed fields are also
3241
// fetched from the database and set on the node.
3242
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3243
        dbNode sqlc.GraphNode) (*models.LightningNode, error) {
×
3244

×
3245
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3246
        if err != nil {
×
3247
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3248
                        err)
×
3249
        }
×
3250

3251
        return buildNodeWithBatchData(dbNode, data)
×
3252
}
3253

3254
// buildNodeWithBatchData builds a models.LightningNode instance
3255
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3256
// features/addresses/extra fields, then the corresponding fields are expected
3257
// to be present in the batchNodeData.
3258
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3259
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3260

×
3261
        if dbNode.Version != int16(ProtocolV1) {
×
3262
                return nil, fmt.Errorf("unsupported node version: %d",
×
3263
                        dbNode.Version)
×
3264
        }
×
3265

3266
        var pub [33]byte
×
3267
        copy(pub[:], dbNode.PubKey)
×
3268

×
3269
        node := &models.LightningNode{
×
3270
                PubKeyBytes: pub,
×
3271
                Features:    lnwire.EmptyFeatureVector(),
×
3272
                LastUpdate:  time.Unix(0, 0),
×
3273
        }
×
3274

×
3275
        if len(dbNode.Signature) == 0 {
×
3276
                return node, nil
×
3277
        }
×
3278

3279
        node.HaveNodeAnnouncement = true
×
3280
        node.AuthSigBytes = dbNode.Signature
×
3281
        node.Alias = dbNode.Alias.String
×
3282
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3283

×
3284
        var err error
×
3285
        if dbNode.Color.Valid {
×
3286
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3287
                if err != nil {
×
3288
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3289
                                err)
×
3290
                }
×
3291
        }
3292

3293
        // Use preloaded features.
3294
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3295
                fv := lnwire.EmptyFeatureVector()
×
3296
                for _, bit := range features {
×
3297
                        fv.Set(lnwire.FeatureBit(bit))
×
3298
                }
×
3299
                node.Features = fv
×
3300
        }
3301

3302
        // Use preloaded addresses.
3303
        addresses, exists := batchData.addresses[dbNode.ID]
×
3304
        if exists && len(addresses) > 0 {
×
3305
                node.Addresses, err = buildNodeAddresses(addresses)
×
3306
                if err != nil {
×
3307
                        return nil, fmt.Errorf("unable to build addresses "+
×
3308
                                "for node(%d): %w", dbNode.ID, err)
×
3309
                }
×
3310
        }
3311

3312
        // Use preloaded extra fields.
3313
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3314
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3315
                if err != nil {
×
3316
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3317
                                "signed fields: %w", err)
×
3318
                }
×
3319
                if len(recs) != 0 {
×
3320
                        node.ExtraOpaqueData = recs
×
3321
                }
×
3322
        }
3323

3324
        return node, nil
×
3325
}
3326

3327
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3328
// with the preloaded data, and executes the provided callback for each node.
3329
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3330
        db SQLQueries, nodes []sqlc.GraphNode,
3331
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3332

×
3333
        // Extract node IDs for batch loading.
×
3334
        nodeIDs := make([]int64, len(nodes))
×
3335
        for i, node := range nodes {
×
3336
                nodeIDs[i] = node.ID
×
3337
        }
×
3338

3339
        // Batch load all related data for this page.
3340
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3341
        if err != nil {
×
3342
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3343
        }
×
3344

3345
        for _, dbNode := range nodes {
×
3346
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3347
                if err != nil {
×
3348
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3349
                                dbNode.ID, err)
×
3350
                }
×
3351

3352
                if err := cb(dbNode.ID, node); err != nil {
×
3353
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3354
                                dbNode.ID, err)
×
3355
                }
×
3356
        }
3357

3358
        return nil
×
3359
}
3360

3361
// getNodeFeatures fetches the feature bits and constructs the feature vector
3362
// for a node with the given DB ID.
3363
func getNodeFeatures(ctx context.Context, db SQLQueries,
3364
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3365

×
3366
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3367
        if err != nil {
×
3368
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3369
                        nodeID, err)
×
3370
        }
×
3371

3372
        features := lnwire.EmptyFeatureVector()
×
3373
        for _, feature := range rows {
×
3374
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3375
        }
×
3376

3377
        return features, nil
×
3378
}
3379

3380
// upsertNode upserts the node record into the database. If the node already
3381
// exists, then the node's information is updated. If the node doesn't exist,
3382
// then a new node is created. The node's features, addresses and extra TLV
3383
// types are also updated. The node's DB ID is returned.
3384
func upsertNode(ctx context.Context, db SQLQueries,
3385
        node *models.LightningNode) (int64, error) {
×
3386

×
3387
        params := sqlc.UpsertNodeParams{
×
3388
                Version: int16(ProtocolV1),
×
3389
                PubKey:  node.PubKeyBytes[:],
×
3390
        }
×
3391

×
3392
        if node.HaveNodeAnnouncement {
×
3393
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3394
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3395
                params.Alias = sqldb.SQLStr(node.Alias)
×
3396
                params.Signature = node.AuthSigBytes
×
3397
        }
×
3398

3399
        nodeID, err := db.UpsertNode(ctx, params)
×
3400
        if err != nil {
×
3401
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3402
                        err)
×
3403
        }
×
3404

3405
        // We can exit here if we don't have the announcement yet.
3406
        if !node.HaveNodeAnnouncement {
×
3407
                return nodeID, nil
×
3408
        }
×
3409

3410
        // Update the node's features.
3411
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3412
        if err != nil {
×
3413
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3414
        }
×
3415

3416
        // Update the node's addresses.
3417
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3418
        if err != nil {
×
3419
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3420
        }
×
3421

3422
        // Convert the flat extra opaque data into a map of TLV types to
3423
        // values.
3424
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3425
        if err != nil {
×
3426
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3427
                        err)
×
3428
        }
×
3429

3430
        // Update the node's extra signed fields.
3431
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3432
        if err != nil {
×
3433
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3434
        }
×
3435

3436
        return nodeID, nil
×
3437
}
3438

3439
// upsertNodeFeatures updates the node's features in the node_features table.
// This
3440
// includes deleting any feature bits no longer present and inserting any new
3441
// feature bits. If the feature bit does not yet exist in the features table,
3442
// then an entry is created in that table first.
3443
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3444
        features *lnwire.FeatureVector) error {
×
3445

×
3446
        // Get any existing features for the node.
×
3447
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3448
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3449
                return err
×
3450
        }
×
3451

3452
        // Copy the nodes latest set of feature bits.
3453
        newFeatures := make(map[int32]struct{})
×
3454
        if features != nil {
×
3455
                for feature := range features.Features() {
×
3456
                        newFeatures[int32(feature)] = struct{}{}
×
3457
                }
×
3458
        }
3459

3460
        // For any current feature that already exists in the DB, remove it from
3461
        // the in-memory map. For any existing feature that does not exist in
3462
        // the in-memory map, delete it from the database.
3463
        for _, feature := range existingFeatures {
×
3464
                // The feature is still present, so there are no updates to be
×
3465
                // made.
×
3466
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3467
                        delete(newFeatures, feature.FeatureBit)
×
3468
                        continue
×
3469
                }
3470

3471
                // The feature is no longer present, so we remove it from the
3472
                // database.
3473
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3474
                        NodeID:     nodeID,
×
3475
                        FeatureBit: feature.FeatureBit,
×
3476
                })
×
3477
                if err != nil {
×
3478
                        return fmt.Errorf("unable to delete node(%d) "+
×
3479
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3480
                                err)
×
3481
                }
×
3482
        }
3483

3484
        // Any remaining entries in newFeatures are new features that need to be
3485
        // added to the database for the first time.
3486
        for feature := range newFeatures {
×
3487
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3488
                        NodeID:     nodeID,
×
3489
                        FeatureBit: feature,
×
3490
                })
×
3491
                if err != nil {
×
3492
                        return fmt.Errorf("unable to insert node(%d) "+
×
3493
                                "feature(%v): %w", nodeID, feature, err)
×
3494
                }
×
3495
        }
3496

3497
        return nil
×
3498
}
3499

3500
// fetchNodeFeatures fetches the features for a node with the given public key.
3501
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3502
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3503

×
3504
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3505
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3506
                        PubKey:  nodePub[:],
×
3507
                        Version: int16(ProtocolV1),
×
3508
                },
×
3509
        )
×
3510
        if err != nil {
×
3511
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3512
                        nodePub, err)
×
3513
        }
×
3514

3515
        features := lnwire.EmptyFeatureVector()
×
3516
        for _, bit := range rows {
×
3517
                features.Set(lnwire.FeatureBit(bit))
×
3518
        }
×
3519

3520
        return features, nil
×
3521
}
3522

3523
// dbAddressType is an enum type that represents the different address types
3524
// that we store in the node_addresses table. The address type determines how
3525
// the address is to be serialised/deserialised.
3526
type dbAddressType uint8
3527

3528
const (
3529
        addressTypeIPv4   dbAddressType = 1
3530
        addressTypeIPv6   dbAddressType = 2
3531
        addressTypeTorV2  dbAddressType = 3
3532
        addressTypeTorV3  dbAddressType = 4
3533
        addressTypeOpaque dbAddressType = math.MaxInt8
3534
)
3535

3536
// collectAddressRecords collects the addresses from the provided
3537
// net.Addr slice and returns a map of dbAddressType to a slice of address
3538
// strings.
3539
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
NEW
3540
        error) {
×
UNCOV
3541

×
UNCOV
3542
        // Copy the node's latest set of addresses.
×
3543
        newAddresses := map[dbAddressType][]string{
×
3544
                addressTypeIPv4:   {},
×
3545
                addressTypeIPv6:   {},
×
3546
                addressTypeTorV2:  {},
×
3547
                addressTypeTorV3:  {},
×
3548
                addressTypeOpaque: {},
×
3549
        }
×
3550
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3551
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3552
        }
×
3553

3554
        for _, address := range addresses {
×
3555
                switch addr := address.(type) {
×
3556
                case *net.TCPAddr:
×
3557
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3558
                                addAddr(addressTypeIPv4, addr)
×
3559
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3560
                                addAddr(addressTypeIPv6, addr)
×
3561
                        } else {
×
NEW
3562
                                return nil, fmt.Errorf("unhandled IP "+
×
NEW
3563
                                        "address: %v", addr)
×
UNCOV
3564
                        }
×
3565

3566
                case *tor.OnionAddr:
×
3567
                        switch len(addr.OnionService) {
×
3568
                        case tor.V2Len:
×
3569
                                addAddr(addressTypeTorV2, addr)
×
3570
                        case tor.V3Len:
×
3571
                                addAddr(addressTypeTorV3, addr)
×
3572
                        default:
×
NEW
3573
                                return nil, fmt.Errorf("invalid length for " +
×
NEW
3574
                                        "a tor address")
×
3575
                        }
3576

3577
                case *lnwire.OpaqueAddrs:
×
3578
                        addAddr(addressTypeOpaque, addr)
×
3579

3580
                default:
×
NEW
3581
                        return nil, fmt.Errorf("unhandled address type: %T",
×
NEW
3582
                                addr)
×
3583
                }
3584
        }
3585

NEW
3586
        return newAddresses, nil
×
3587
}
3588

3589
// upsertNodeAddresses updates the node's addresses in the database. This
3590
// includes deleting any existing addresses and inserting the new set of
3591
// addresses. The deletion is necessary since the ordering of the addresses may
3592
// change, and we need to ensure that the database reflects the latest set of
3593
// addresses so that at the time of reconstructing the node announcement, the
3594
// order is preserved and the signature over the message remains valid.
3595
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
NEW
3596
        addresses []net.Addr) error {
×
NEW
3597

×
NEW
3598
        // Delete any existing addresses for the node. This is required since
×
NEW
3599
        // even if the new set of addresses is the same, the ordering may have
×
NEW
3600
        // changed for a given address type.
×
NEW
3601
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
NEW
3602
        if err != nil {
×
NEW
3603
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
NEW
3604
                        nodeID, err)
×
NEW
3605
        }
×
3606

NEW
3607
        newAddresses, err := collectAddressRecords(addresses)
×
NEW
3608
        if err != nil {
×
NEW
3609
                return err
×
NEW
3610
        }
×
3611

3612
        // Any remaining entries in newAddresses are new addresses that need to
3613
        // be added to the database for the first time.
3614
        for addrType, addrList := range newAddresses {
×
3615
                for position, addr := range addrList {
×
NEW
3616
                        err := db.UpsertNodeAddress(
×
NEW
3617
                                ctx, sqlc.UpsertNodeAddressParams{
×
3618
                                        NodeID:   nodeID,
×
3619
                                        Type:     int16(addrType),
×
3620
                                        Address:  addr,
×
3621
                                        Position: int32(position),
×
3622
                                },
×
3623
                        )
×
3624
                        if err != nil {
×
3625
                                return fmt.Errorf("unable to insert "+
×
3626
                                        "node(%d) address(%v): %w", nodeID,
×
3627
                                        addr, err)
×
3628
                        }
×
3629
                }
3630
        }
3631

3632
        return nil
×
3633
}
3634

3635
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3636
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3637
        error) {
×
3638

×
3639
        // GetNodeAddresses ensures that the addresses for a given type are
×
3640
        // returned in the same order as they were inserted.
×
3641
        rows, err := db.GetNodeAddresses(ctx, id)
×
3642
        if err != nil {
×
3643
                return nil, err
×
3644
        }
×
3645

3646
        addresses := make([]net.Addr, 0, len(rows))
×
3647
        for _, row := range rows {
×
3648
                address := row.Address
×
3649

×
3650
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3651
                if err != nil {
×
3652
                        return nil, fmt.Errorf("unable to parse address "+
×
3653
                                "for node(%d): %v: %w", id, address, err)
×
3654
                }
×
3655

3656
                addresses = append(addresses, addr)
×
3657
        }
3658

3659
        // If we have no addresses, then we'll return nil instead of an
3660
        // empty slice.
3661
        if len(addresses) == 0 {
×
3662
                addresses = nil
×
3663
        }
×
3664

3665
        return addresses, nil
×
3666
}
3667

3668
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3669
// database. This includes updating any existing types, inserting any new types,
3670
// and deleting any types that are no longer present.
3671
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3672
        nodeID int64, extraFields map[uint64][]byte) error {
×
3673

×
3674
        // Get any existing extra signed fields for the node.
×
3675
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3676
        if err != nil {
×
3677
                return err
×
3678
        }
×
3679

3680
        // Make a lookup map of the existing field types so that we can use it
3681
        // to keep track of any fields we should delete.
3682
        m := make(map[uint64]bool)
×
3683
        for _, field := range existingFields {
×
3684
                m[uint64(field.Type)] = true
×
3685
        }
×
3686

3687
        // For all the new fields, we'll upsert them and remove them from the
3688
        // map of existing fields.
3689
        for tlvType, value := range extraFields {
×
3690
                err = db.UpsertNodeExtraType(
×
3691
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3692
                                NodeID: nodeID,
×
3693
                                Type:   int64(tlvType),
×
3694
                                Value:  value,
×
3695
                        },
×
3696
                )
×
3697
                if err != nil {
×
3698
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3699
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3700
                }
×
3701

3702
                // Remove the field from the map of existing fields if it was
3703
                // present.
3704
                delete(m, tlvType)
×
3705
        }
3706

3707
        // For all the fields that are left in the map of existing fields, we'll
3708
        // delete them as they are no longer present in the new set of fields.
3709
        for tlvType := range m {
×
3710
                err = db.DeleteExtraNodeType(
×
3711
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3712
                                NodeID: nodeID,
×
3713
                                Type:   int64(tlvType),
×
3714
                        },
×
3715
                )
×
3716
                if err != nil {
×
3717
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3718
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3719
                }
×
3720
        }
3721

3722
        return nil
×
3723
}
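// Illustrative sketch of the diff-style behaviour described above, using
// hypothetical TLV types: if the node previously stored types {1, 2} and the
// announcement now carries {1, 3}, then 1 and 3 are upserted and 2 is
// deleted.
//
//	extra := map[uint64][]byte{
//		1: {0x0a},
//		3: {0x0c},
//	}
//	if err := upsertNodeExtraSignedFields(ctx, db, nodeID, extra); err != nil {
//		return err
//	}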
3724

3725
// srcNodeInfo holds the information about the source node of the graph.
3726
type srcNodeInfo struct {
3727
        // id is the DB level ID of the source node entry in the "nodes" table.
3728
        id int64
3729

3730
        // pub is the public key of the source node.
3731
        pub route.Vertex
3732
}
3733

3734
// getSourceNode returns the DB node ID and pub key of the source node for the
3735
// specified protocol version.
3736
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3737
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3738

×
3739
        s.srcNodeMu.Lock()
×
3740
        defer s.srcNodeMu.Unlock()
×
3741

×
3742
        // If we already have the source node ID and pub key cached, then
×
3743
        // return them.
×
3744
        if info, ok := s.srcNodes[version]; ok {
×
3745
                return info.id, info.pub, nil
×
3746
        }
×
3747

3748
        var pubKey route.Vertex
×
3749

×
3750
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3751
        if err != nil {
×
3752
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3753
                        err)
×
3754
        }
×
3755

3756
        if len(nodes) == 0 {
×
3757
                return 0, pubKey, ErrSourceNodeNotSet
×
3758
        } else if len(nodes) > 1 {
×
3759
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3760
                        "protocol %s found", version)
×
3761
        }
×
3762

3763
        copy(pubKey[:], nodes[0].PubKey)
×
3764

×
3765
        s.srcNodes[version] = &srcNodeInfo{
×
3766
                id:  nodes[0].NodeID,
×
3767
                pub: pubKey,
×
3768
        }
×
3769

×
3770
        return nodes[0].NodeID, pubKey, nil
×
3771
}
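// Illustrative usage sketch (hypothetical; ctx and db are assumed to come
// from the enclosing transaction):
//
//	id, pub, err := s.getSourceNode(ctx, db, ProtocolV1)
//	if errors.Is(err, ErrSourceNodeNotSet) {
//		// No source node has been persisted for this version yet.
//	} else if err != nil {
//		return err
//	}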
3772

3773
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV stream.
3774
// This then produces a map from TLV type to value. If the input is not a
3775
// valid TLV stream, then an error is returned.
3776
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3777
        r := bytes.NewReader(data)
×
3778

×
3779
        tlvStream, err := tlv.NewStream()
×
3780
        if err != nil {
×
3781
                return nil, err
×
3782
        }
×
3783

3784
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3785
        // pass it into the P2P decoding variant.
3786
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3787
        if err != nil {
×
3788
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3789
        }
×
3790
        if len(parsedTypes) == 0 {
×
3791
                return nil, nil
×
3792
        }
×
3793

3794
        records := make(map[uint64][]byte)
×
3795
        for k, v := range parsedTypes {
×
3796
                records[uint64(k)] = v
×
3797
        }
×
3798

3799
        return records, nil
×
3800
}
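// Illustrative example of the expected input/output shape, using a
// hypothetical single-record TLV stream (type 1, length 1, value 0xff):
//
//	records, err := marshalExtraOpaqueData([]byte{0x01, 0x01, 0xff})
//	// On success, records == map[uint64][]byte{1: {0xff}}.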
3801

3802
// insertChannel inserts a new channel record into the database.
3803
func insertChannel(ctx context.Context, db SQLQueries,
NEW
3804
        edge *models.ChannelEdgeInfo) error {
×
3805

×
3806
        // Make sure that at least a "shell" entry for each node is present in
×
3807
        // the nodes table.
×
3808
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3809
        if err != nil {
×
NEW
3810
                return fmt.Errorf("unable to create shell node: %w", err)
×
3811
        }
×
3812

3813
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3814
        if err != nil {
×
NEW
3815
                return fmt.Errorf("unable to create shell node: %w", err)
×
3816
        }
×
3817

3818
        var capacity sql.NullInt64
×
3819
        if edge.Capacity != 0 {
×
3820
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3821
        }
×
3822

3823
        createParams := sqlc.CreateChannelParams{
×
3824
                Version:     int16(ProtocolV1),
×
3825
                Scid:        channelIDToBytes(edge.ChannelID),
×
3826
                NodeID1:     node1DBID,
×
3827
                NodeID2:     node2DBID,
×
3828
                Outpoint:    edge.ChannelPoint.String(),
×
3829
                Capacity:    capacity,
×
3830
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3831
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3832
        }
×
3833

×
3834
        if edge.AuthProof != nil {
×
3835
                proof := edge.AuthProof
×
3836

×
3837
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3838
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3839
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3840
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3841
        }
×
3842

3843
        // Insert the new channel record.
3844
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3845
        if err != nil {
×
NEW
3846
                return err
×
3847
        }
×
3848

3849
        // Insert any channel features.
3850
        for feature := range edge.Features.Features() {
×
3851
                err = db.InsertChannelFeature(
×
3852
                        ctx, sqlc.InsertChannelFeatureParams{
×
3853
                                ChannelID:  dbChanID,
×
3854
                                FeatureBit: int32(feature),
×
3855
                        },
×
3856
                )
×
3857
                if err != nil {
×
NEW
3858
                        return fmt.Errorf("unable to insert channel(%d) "+
×
3859
                                "feature(%v): %w", dbChanID, feature, err)
×
3860
                }
×
3861
        }
3862

3863
        // Finally, insert any extra TLV fields in the channel announcement.
3864
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3865
        if err != nil {
×
NEW
3866
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
NEW
3867
                        err)
×
UNCOV
3868
        }
×
3869

3870
        for tlvType, value := range extra {
×
NEW
3871
                err := db.UpsertChannelExtraType(
×
NEW
3872
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
3873
                                ChannelID: dbChanID,
×
3874
                                Type:      int64(tlvType),
×
3875
                                Value:     value,
×
3876
                        },
×
3877
                )
×
3878
                if err != nil {
×
NEW
3879
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
NEW
3880
                                "extra signed field(%v): %w", edge.ChannelID,
×
NEW
3881
                                tlvType, err)
×
UNCOV
3882
                }
×
3883
        }
3884

NEW
3885
        return nil
×
3886
}
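// Illustrative ordering sketch (hypothetical caller): the insert is shell
// nodes first, then the channel row, then features and extra TLV records,
// all against the same SQLQueries transaction.
//
//	if err := insertChannel(ctx, db, edge); err != nil {
//		return fmt.Errorf("unable to insert channel: %w", err)
//	}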
3887

3888
// maybeCreateShellNode checks if a shell node entry exists for the
3889
// given public key. If it does not exist, then a new shell node entry is
3890
// created. The ID of the node is returned. A shell node only has a protocol
3891
// version and public key persisted.
3892
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3893
        pubKey route.Vertex) (int64, error) {
×
3894

×
3895
        dbNode, err := db.GetNodeByPubKey(
×
3896
                ctx, sqlc.GetNodeByPubKeyParams{
×
3897
                        PubKey:  pubKey[:],
×
3898
                        Version: int16(ProtocolV1),
×
3899
                },
×
3900
        )
×
3901
        // The node exists. Return the ID.
×
3902
        if err == nil {
×
3903
                return dbNode.ID, nil
×
3904
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3905
                return 0, err
×
3906
        }
×
3907

3908
        // Otherwise, the node does not exist, so we create a shell entry for
3909
        // it.
3910
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3911
                Version: int16(ProtocolV1),
×
3912
                PubKey:  pubKey[:],
×
3913
        })
×
3914
        if err != nil {
×
3915
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3916
        }
×
3917

3918
        return id, nil
×
3919
}
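// Illustrative sketch (hypothetical; assumes UpsertNode is idempotent on the
// (version, pub_key) pair): because the fallback path goes through
// UpsertNode, calling this twice for the same public key is expected to
// yield the same node ID rather than a duplicate row.
//
//	id1, err1 := maybeCreateShellNode(ctx, db, pubKey)
//	id2, err2 := maybeCreateShellNode(ctx, db, pubKey)
//	// With err1 == nil and err2 == nil, id1 == id2.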
3920

3921
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3922
// the database. This includes deleting any existing types and then inserting
3923
// the new types.
3924
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3925
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3926

×
3927
        // Delete all existing extra signed fields for the channel policy.
×
3928
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3929
        if err != nil {
×
3930
                return fmt.Errorf("unable to delete "+
×
3931
                        "existing policy extra signed fields for policy %d: %w",
×
3932
                        chanPolicyID, err)
×
3933
        }
×
3934

3935
        // Insert all new extra signed fields for the channel policy.
3936
        for tlvType, value := range extraFields {
×
NEW
3937
                err = db.UpsertChanPolicyExtraType(
×
NEW
3938
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
3939
                                ChannelPolicyID: chanPolicyID,
×
3940
                                Type:            int64(tlvType),
×
3941
                                Value:           value,
×
3942
                        },
×
3943
                )
×
3944
                if err != nil {
×
3945
                        return fmt.Errorf("unable to insert "+
×
3946
                                "channel_policy(%d) extra signed field(%v): %w",
×
3947
                                chanPolicyID, tlvType, err)
×
3948
                }
×
3949
        }
3950

3951
        return nil
×
3952
}
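// Illustrative sketch (hypothetical caller): unlike the node variant above,
// this helper simply clears and re-inserts, so passing an empty map removes
// all extra signed fields for the policy.
//
//	err := upsertChanPolicyExtraSignedFields(
//		ctx, db, chanPolicyID, map[uint64][]byte{},
//	)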
3953

3954
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3955
// provided dbChan record and also fetches any other required information
3956
// to construct the edge info.
3957
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
3958
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
3959
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3960

×
3961
        data, err := batchLoadChannelData(
×
3962
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
3963
        )
×
3964
        if err != nil {
×
3965
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3966
                        err)
×
3967
        }
×
3968

3969
        return buildEdgeInfoWithBatchData(
×
3970
                cfg.ChainHash, dbChan, node1, node2, data,
×
3971
        )
×
3972
}
3973

3974
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3975
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3976
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3977
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3978

×
3979
        if dbChan.Version != int16(ProtocolV1) {
×
3980
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3981
                        dbChan.Version)
×
3982
        }
×
3983

3984
        // Use the pre-loaded features and extra types.
3985
        fv := lnwire.EmptyFeatureVector()
×
3986
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3987
                for _, bit := range features {
×
3988
                        fv.Set(lnwire.FeatureBit(bit))
×
3989
                }
×
3990
        }
3991

3992
        var extras map[uint64][]byte
×
3993
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3994
        if exists {
×
3995
                extras = channelExtras
×
3996
        } else {
×
3997
                extras = make(map[uint64][]byte)
×
3998
        }
×
3999

4000
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4001
        if err != nil {
×
4002
                return nil, err
×
4003
        }
×
4004

4005
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4006
        if err != nil {
×
4007
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4008
                        "fields: %w", err)
×
4009
        }
×
4010
        if recs == nil {
×
4011
                recs = make([]byte, 0)
×
4012
        }
×
4013

4014
        var btcKey1, btcKey2 route.Vertex
×
4015
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4016
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4017

×
4018
        channel := &models.ChannelEdgeInfo{
×
4019
                ChainHash:        chain,
×
4020
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4021
                NodeKey1Bytes:    node1,
×
4022
                NodeKey2Bytes:    node2,
×
4023
                BitcoinKey1Bytes: btcKey1,
×
4024
                BitcoinKey2Bytes: btcKey2,
×
4025
                ChannelPoint:     *op,
×
4026
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4027
                Features:         fv,
×
4028
                ExtraOpaqueData:  recs,
×
4029
        }
×
4030

×
4031
        // We always set all the signatures at the same time, so we can
×
4032
        // safely check if one signature is present to determine if we have the
×
4033
        // rest of the signatures for the auth proof.
×
4034
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4035
                channel.AuthProof = &models.ChannelAuthProof{
×
4036
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4037
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4038
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4039
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4040
                }
×
4041
        }
×
4042

4043
        return channel, nil
×
4044
}
4045

4046
// buildNodeVertices is a helper that converts raw node public keys
4047
// into route.Vertex instances.
4048
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4049
        route.Vertex, error) {
×
4050

×
4051
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4052
        if err != nil {
×
4053
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4054
                        "create vertex from node1 pubkey: %w", err)
×
4055
        }
×
4056

4057
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4058
        if err != nil {
×
4059
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4060
                        "create vertex from node2 pubkey: %w", err)
×
4061
        }
×
4062

4063
        return node1Vertex, node2Vertex, nil
×
4064
}
4065

4066
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and also
4067
// retrieves all the extra info required to build the complete
4068
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4069
// the provided sqlc.GraphChannelPolicy records are nil.
4070
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4071
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4072
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4073
        *models.ChannelEdgePolicy, error) {
×
4074

×
4075
        if dbPol1 == nil && dbPol2 == nil {
×
4076
                return nil, nil, nil
×
4077
        }
×
4078

4079
        var policyIDs = make([]int64, 0, 2)
×
4080
        if dbPol1 != nil {
×
4081
                policyIDs = append(policyIDs, dbPol1.ID)
×
4082
        }
×
4083
        if dbPol2 != nil {
×
4084
                policyIDs = append(policyIDs, dbPol2.ID)
×
4085
        }
×
4086

4087
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4088
        if err != nil {
×
4089
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4090
                        "data: %w", err)
×
4091
        }
×
4092

4093
        pol1, err := buildChanPolicyWithBatchData(
×
4094
                dbPol1, channelID, node2, batchData,
×
4095
        )
×
4096
        if err != nil {
×
4097
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4098
        }
×
4099

4100
        pol2, err := buildChanPolicyWithBatchData(
×
4101
                dbPol2, channelID, node1, batchData,
×
4102
        )
×
4103
        if err != nil {
×
4104
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4105
        }
×
4106

4107
        return pol1, pol2, nil
×
4108
}
4109

4110
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4111
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4112
// then nil is returned for it.
4113
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4114
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4115
        *models.CachedEdgePolicy, error) {
×
4116

×
4117
        var p1, p2 *models.CachedEdgePolicy
×
4118
        if dbPol1 != nil {
×
4119
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4120
                if err != nil {
×
4121
                        return nil, nil, err
×
4122
                }
×
4123

4124
                p1 = models.NewCachedPolicy(policy1)
×
4125
        }
4126
        if dbPol2 != nil {
×
4127
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4128
                if err != nil {
×
4129
                        return nil, nil, err
×
4130
                }
×
4131

4132
                p2 = models.NewCachedPolicy(policy2)
×
4133
        }
4134

4135
        return p1, p2, nil
×
4136
}
4137

4138
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4139
// provided sqlc.GraphChannelPolicy and other required information.
4140
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4141
        extras map[uint64][]byte,
4142
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4143

×
4144
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4145
        if err != nil {
×
4146
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4147
                        "fields: %w", err)
×
4148
        }
×
4149

4150
        var inboundFee fn.Option[lnwire.Fee]
×
4151
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4152
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4153

×
4154
                inboundFee = fn.Some(lnwire.Fee{
×
4155
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4156
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4157
                })
×
4158
        }
×
4159

4160
        return &models.ChannelEdgePolicy{
×
4161
                SigBytes:  dbPolicy.Signature,
×
4162
                ChannelID: channelID,
×
4163
                LastUpdate: time.Unix(
×
4164
                        dbPolicy.LastUpdate.Int64, 0,
×
4165
                ),
×
4166
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4167
                        dbPolicy.MessageFlags,
×
4168
                ),
×
4169
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4170
                        dbPolicy.ChannelFlags,
×
4171
                ),
×
4172
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4173
                MinHTLC: lnwire.MilliSatoshi(
×
4174
                        dbPolicy.MinHtlcMsat,
×
4175
                ),
×
4176
                MaxHTLC: lnwire.MilliSatoshi(
×
4177
                        dbPolicy.MaxHtlcMsat.Int64,
×
4178
                ),
×
4179
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4180
                        dbPolicy.BaseFeeMsat,
×
4181
                ),
×
4182
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4183
                ToNode:                    toNode,
×
4184
                InboundFee:                inboundFee,
×
4185
                ExtraOpaqueData:           recs,
×
4186
        }, nil
×
4187
}
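// Illustrative note (hypothetical column values): the inbound fee is only
// populated when at least one of the two inbound columns is valid. With
// InboundBaseFeeMsat = 1000 (valid) and InboundFeeRateMilliMsat = NULL, the
// resulting policy carries:
//
//	fn.Some(lnwire.Fee{BaseFee: 1000, FeeRate: 0})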
4188

4189
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4190
// row which is expected to be a sqlc type that contains channel policy
4191
// information. It returns two policies, which may be nil if the policy
4192
// information is not present in the row.
4193
//
4194
//nolint:ll,dupl,funlen
4195
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4196
        *sqlc.GraphChannelPolicy, error) {
×
4197

×
4198
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4199
        switch r := row.(type) {
×
4200
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4201
                if r.Policy1Timelock.Valid {
×
4202
                        policy1 = &sqlc.GraphChannelPolicy{
×
4203
                                Timelock:                r.Policy1Timelock.Int32,
×
4204
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4205
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4206
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4207
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4208
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4209
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4210
                                Disabled:                r.Policy1Disabled,
×
4211
                                MessageFlags:            r.Policy1MessageFlags,
×
4212
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4213
                        }
×
4214
                }
×
4215
                if r.Policy2Timelock.Valid {
×
4216
                        policy2 = &sqlc.GraphChannelPolicy{
×
4217
                                Timelock:                r.Policy2Timelock.Int32,
×
4218
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4219
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4220
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4221
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4222
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4223
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4224
                                Disabled:                r.Policy2Disabled,
×
4225
                                MessageFlags:            r.Policy2MessageFlags,
×
4226
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4227
                        }
×
4228
                }
×
4229

4230
                return policy1, policy2, nil
×
4231

4232
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4233
                if r.Policy1ID.Valid {
×
4234
                        policy1 = &sqlc.GraphChannelPolicy{
×
4235
                                ID:                      r.Policy1ID.Int64,
×
4236
                                Version:                 r.Policy1Version.Int16,
×
4237
                                ChannelID:               r.GraphChannel.ID,
×
4238
                                NodeID:                  r.Policy1NodeID.Int64,
×
4239
                                Timelock:                r.Policy1Timelock.Int32,
×
4240
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4241
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4242
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4243
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4244
                                LastUpdate:              r.Policy1LastUpdate,
×
4245
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4246
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4247
                                Disabled:                r.Policy1Disabled,
×
4248
                                MessageFlags:            r.Policy1MessageFlags,
×
4249
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4250
                                Signature:               r.Policy1Signature,
×
4251
                        }
×
4252
                }
×
4253
                if r.Policy2ID.Valid {
×
4254
                        policy2 = &sqlc.GraphChannelPolicy{
×
4255
                                ID:                      r.Policy2ID.Int64,
×
4256
                                Version:                 r.Policy2Version.Int16,
×
4257
                                ChannelID:               r.GraphChannel.ID,
×
4258
                                NodeID:                  r.Policy2NodeID.Int64,
×
4259
                                Timelock:                r.Policy2Timelock.Int32,
×
4260
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4261
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4262
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4263
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4264
                                LastUpdate:              r.Policy2LastUpdate,
×
4265
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4266
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4267
                                Disabled:                r.Policy2Disabled,
×
4268
                                MessageFlags:            r.Policy2MessageFlags,
×
4269
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4270
                                Signature:               r.Policy2Signature,
×
4271
                        }
×
4272
                }
×
4273

4274
                return policy1, policy2, nil
×
4275

4276
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4277
                if r.Policy1ID.Valid {
×
4278
                        policy1 = &sqlc.GraphChannelPolicy{
×
4279
                                ID:                      r.Policy1ID.Int64,
×
4280
                                Version:                 r.Policy1Version.Int16,
×
4281
                                ChannelID:               r.GraphChannel.ID,
×
4282
                                NodeID:                  r.Policy1NodeID.Int64,
×
4283
                                Timelock:                r.Policy1Timelock.Int32,
×
4284
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4285
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4286
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4287
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4288
                                LastUpdate:              r.Policy1LastUpdate,
×
4289
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4290
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4291
                                Disabled:                r.Policy1Disabled,
×
4292
                                MessageFlags:            r.Policy1MessageFlags,
×
4293
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4294
                                Signature:               r.Policy1Signature,
×
4295
                        }
×
4296
                }
×
4297
                if r.Policy2ID.Valid {
×
4298
                        policy2 = &sqlc.GraphChannelPolicy{
×
4299
                                ID:                      r.Policy2ID.Int64,
×
4300
                                Version:                 r.Policy2Version.Int16,
×
4301
                                ChannelID:               r.GraphChannel.ID,
×
4302
                                NodeID:                  r.Policy2NodeID.Int64,
×
4303
                                Timelock:                r.Policy2Timelock.Int32,
×
4304
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4305
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4306
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4307
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4308
                                LastUpdate:              r.Policy2LastUpdate,
×
4309
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4310
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4311
                                Disabled:                r.Policy2Disabled,
×
4312
                                MessageFlags:            r.Policy2MessageFlags,
×
4313
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4314
                                Signature:               r.Policy2Signature,
×
4315
                        }
×
4316
                }
×
4317

4318
                return policy1, policy2, nil
×
4319

4320
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4321
                if r.Policy1ID.Valid {
×
4322
                        policy1 = &sqlc.GraphChannelPolicy{
×
4323
                                ID:                      r.Policy1ID.Int64,
×
4324
                                Version:                 r.Policy1Version.Int16,
×
4325
                                ChannelID:               r.GraphChannel.ID,
×
4326
                                NodeID:                  r.Policy1NodeID.Int64,
×
4327
                                Timelock:                r.Policy1Timelock.Int32,
×
4328
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4329
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4330
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4331
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4332
                                LastUpdate:              r.Policy1LastUpdate,
×
4333
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4334
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4335
                                Disabled:                r.Policy1Disabled,
×
4336
                                MessageFlags:            r.Policy1MessageFlags,
×
4337
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4338
                                Signature:               r.Policy1Signature,
×
4339
                        }
×
4340
                }
×
4341
                if r.Policy2ID.Valid {
×
4342
                        policy2 = &sqlc.GraphChannelPolicy{
×
4343
                                ID:                      r.Policy2ID.Int64,
×
4344
                                Version:                 r.Policy2Version.Int16,
×
4345
                                ChannelID:               r.GraphChannel.ID,
×
4346
                                NodeID:                  r.Policy2NodeID.Int64,
×
4347
                                Timelock:                r.Policy2Timelock.Int32,
×
4348
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4349
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4350
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4351
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4352
                                LastUpdate:              r.Policy2LastUpdate,
×
4353
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4354
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4355
                                Disabled:                r.Policy2Disabled,
×
4356
                                MessageFlags:            r.Policy2MessageFlags,
×
4357
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4358
                                Signature:               r.Policy2Signature,
×
4359
                        }
×
4360
                }
×
4361

4362
                return policy1, policy2, nil
×
4363

4364
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4365
                if r.Policy1ID.Valid {
×
4366
                        policy1 = &sqlc.GraphChannelPolicy{
×
4367
                                ID:                      r.Policy1ID.Int64,
×
4368
                                Version:                 r.Policy1Version.Int16,
×
4369
                                ChannelID:               r.GraphChannel.ID,
×
4370
                                NodeID:                  r.Policy1NodeID.Int64,
×
4371
                                Timelock:                r.Policy1Timelock.Int32,
×
4372
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4373
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4374
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4375
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4376
                                LastUpdate:              r.Policy1LastUpdate,
×
4377
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4378
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4379
                                Disabled:                r.Policy1Disabled,
×
4380
                                MessageFlags:            r.Policy1MessageFlags,
×
4381
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4382
                                Signature:               r.Policy1Signature,
×
4383
                        }
×
4384
                }
×
4385
                if r.Policy2ID.Valid {
×
4386
                        policy2 = &sqlc.GraphChannelPolicy{
×
4387
                                ID:                      r.Policy2ID.Int64,
×
4388
                                Version:                 r.Policy2Version.Int16,
×
4389
                                ChannelID:               r.GraphChannel.ID,
×
4390
                                NodeID:                  r.Policy2NodeID.Int64,
×
4391
                                Timelock:                r.Policy2Timelock.Int32,
×
4392
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4393
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4394
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4395
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4396
                                LastUpdate:              r.Policy2LastUpdate,
×
4397
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4398
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4399
                                Disabled:                r.Policy2Disabled,
×
4400
                                MessageFlags:            r.Policy2MessageFlags,
×
4401
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4402
                                Signature:               r.Policy2Signature,
×
4403
                        }
×
4404
                }
×
4405

4406
                return policy1, policy2, nil
×
4407

4408
        case sqlc.ListChannelsForNodeIDsRow:
×
4409
                if r.Policy1ID.Valid {
×
4410
                        policy1 = &sqlc.GraphChannelPolicy{
×
4411
                                ID:                      r.Policy1ID.Int64,
×
4412
                                Version:                 r.Policy1Version.Int16,
×
4413
                                ChannelID:               r.GraphChannel.ID,
×
4414
                                NodeID:                  r.Policy1NodeID.Int64,
×
4415
                                Timelock:                r.Policy1Timelock.Int32,
×
4416
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4417
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4418
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4419
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4420
                                LastUpdate:              r.Policy1LastUpdate,
×
4421
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4422
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4423
                                Disabled:                r.Policy1Disabled,
×
4424
                                MessageFlags:            r.Policy1MessageFlags,
×
4425
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4426
                                Signature:               r.Policy1Signature,
×
4427
                        }
×
4428
                }
×
4429
                if r.Policy2ID.Valid {
×
4430
                        policy2 = &sqlc.GraphChannelPolicy{
×
4431
                                ID:                      r.Policy2ID.Int64,
×
4432
                                Version:                 r.Policy2Version.Int16,
×
4433
                                ChannelID:               r.GraphChannel.ID,
×
4434
                                NodeID:                  r.Policy2NodeID.Int64,
×
4435
                                Timelock:                r.Policy2Timelock.Int32,
×
4436
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4437
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4438
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4439
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4440
                                LastUpdate:              r.Policy2LastUpdate,
×
4441
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4442
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4443
                                Disabled:                r.Policy2Disabled,
×
4444
                                MessageFlags:            r.Policy2MessageFlags,
×
4445
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4446
                                Signature:               r.Policy2Signature,
×
4447
                        }
×
4448
                }
×
4449

4450
                return policy1, policy2, nil
×
4451

4452
        case sqlc.ListChannelsByNodeIDRow:
×
4453
                if r.Policy1ID.Valid {
×
4454
                        policy1 = &sqlc.GraphChannelPolicy{
×
4455
                                ID:                      r.Policy1ID.Int64,
×
4456
                                Version:                 r.Policy1Version.Int16,
×
4457
                                ChannelID:               r.GraphChannel.ID,
×
4458
                                NodeID:                  r.Policy1NodeID.Int64,
×
4459
                                Timelock:                r.Policy1Timelock.Int32,
×
4460
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4461
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4462
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4463
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4464
                                LastUpdate:              r.Policy1LastUpdate,
×
4465
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4466
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4467
                                Disabled:                r.Policy1Disabled,
×
4468
                                MessageFlags:            r.Policy1MessageFlags,
×
4469
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4470
                                Signature:               r.Policy1Signature,
×
4471
                        }
×
4472
                }
×
4473
                if r.Policy2ID.Valid {
×
4474
                        policy2 = &sqlc.GraphChannelPolicy{
×
4475
                                ID:                      r.Policy2ID.Int64,
×
4476
                                Version:                 r.Policy2Version.Int16,
×
4477
                                ChannelID:               r.GraphChannel.ID,
×
4478
                                NodeID:                  r.Policy2NodeID.Int64,
×
4479
                                Timelock:                r.Policy2Timelock.Int32,
×
4480
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4481
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4482
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4483
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4484
                                LastUpdate:              r.Policy2LastUpdate,
×
4485
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4486
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4487
                                Disabled:                r.Policy2Disabled,
×
4488
                                MessageFlags:            r.Policy2MessageFlags,
×
4489
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4490
                                Signature:               r.Policy2Signature,
×
4491
                        }
×
4492
                }
×
4493

4494
                return policy1, policy2, nil
×
4495

4496
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4497
                if r.Policy1ID.Valid {
×
4498
                        policy1 = &sqlc.GraphChannelPolicy{
×
4499
                                ID:                      r.Policy1ID.Int64,
×
4500
                                Version:                 r.Policy1Version.Int16,
×
4501
                                ChannelID:               r.GraphChannel.ID,
×
4502
                                NodeID:                  r.Policy1NodeID.Int64,
×
4503
                                Timelock:                r.Policy1Timelock.Int32,
×
4504
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4505
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4506
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4507
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4508
                                LastUpdate:              r.Policy1LastUpdate,
×
4509
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4510
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4511
                                Disabled:                r.Policy1Disabled,
×
4512
                                MessageFlags:            r.Policy1MessageFlags,
×
4513
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4514
                                Signature:               r.Policy1Signature,
×
4515
                        }
×
4516
                }
×
4517
                if r.Policy2ID.Valid {
×
4518
                        policy2 = &sqlc.GraphChannelPolicy{
×
4519
                                ID:                      r.Policy2ID.Int64,
×
4520
                                Version:                 r.Policy2Version.Int16,
×
4521
                                ChannelID:               r.GraphChannel.ID,
×
4522
                                NodeID:                  r.Policy2NodeID.Int64,
×
4523
                                Timelock:                r.Policy2Timelock.Int32,
×
4524
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4525
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4526
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4527
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4528
                                LastUpdate:              r.Policy2LastUpdate,
×
4529
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4530
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4531
                                Disabled:                r.Policy2Disabled,
×
4532
                                MessageFlags:            r.Policy2MessageFlags,
×
4533
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4534
                                Signature:               r.Policy2Signature,
×
4535
                        }
×
4536
                }
×
4537

4538
                return policy1, policy2, nil
×
4539

4540
        case sqlc.GetChannelsByIDsRow:
×
4541
                if r.Policy1ID.Valid {
×
4542
                        policy1 = &sqlc.GraphChannelPolicy{
×
4543
                                ID:                      r.Policy1ID.Int64,
×
4544
                                Version:                 r.Policy1Version.Int16,
×
4545
                                ChannelID:               r.GraphChannel.ID,
×
4546
                                NodeID:                  r.Policy1NodeID.Int64,
×
4547
                                Timelock:                r.Policy1Timelock.Int32,
×
4548
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4549
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4550
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4551
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4552
                                LastUpdate:              r.Policy1LastUpdate,
×
4553
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4554
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4555
                                Disabled:                r.Policy1Disabled,
×
4556
                                MessageFlags:            r.Policy1MessageFlags,
×
4557
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4558
                                Signature:               r.Policy1Signature,
×
4559
                        }
×
4560
                }
×
4561
                if r.Policy2ID.Valid {
×
4562
                        policy2 = &sqlc.GraphChannelPolicy{
×
4563
                                ID:                      r.Policy2ID.Int64,
×
4564
                                Version:                 r.Policy2Version.Int16,
×
4565
                                ChannelID:               r.GraphChannel.ID,
×
4566
                                NodeID:                  r.Policy2NodeID.Int64,
×
4567
                                Timelock:                r.Policy2Timelock.Int32,
×
4568
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4569
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4570
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4571
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4572
                                LastUpdate:              r.Policy2LastUpdate,
×
4573
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4574
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4575
                                Disabled:                r.Policy2Disabled,
×
4576
                                MessageFlags:            r.Policy2MessageFlags,
×
4577
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4578
                                Signature:               r.Policy2Signature,
×
4579
                        }
×
4580
                }
×
4581

4582
                return policy1, policy2, nil
×
4583

4584
        default:
×
4585
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4586
                        "extractChannelPolicies: %T", r)
×
4587
        }
4588
}
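// Usage sketch (hypothetical caller): callers pass whichever sqlc row type
// they received from a policy-joining query and get back at most two
// policies, either of which may be nil when that side has not yet announced
// a channel update.
//
//	dbPol1, dbPol2, err := extractChannelPolicies(row)
//	if err != nil {
//		return err
//	}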
4589

4590
// channelIDToBytes converts a channel ID (SCID) to a byte array
4591
// representation.
4592
func channelIDToBytes(channelID uint64) []byte {
×
4593
        var chanIDB [8]byte
×
4594
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4595

×
4596
        return chanIDB[:]
×
4597
}
×
4598

4599
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4600
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4601
        if len(addresses) == 0 {
×
4602
                return nil, nil
×
4603
        }
×
4604

4605
        result := make([]net.Addr, 0, len(addresses))
×
4606
        for _, addr := range addresses {
×
4607
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4608
                if err != nil {
×
4609
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4610
                                "of type %d: %w", addr.address, addr.addrType,
×
4611
                                err)
×
4612
                }
×
4613
                if netAddr != nil {
×
4614
                        result = append(result, netAddr)
×
4615
                }
×
4616
        }
4617

4618
        // If we have no valid addresses, return nil instead of an empty slice.
4619
        if len(result) == 0 {
×
4620
                return nil, nil
×
4621
        }
×
4622

4623
        return result, nil
×
4624
}
4625

4626
// parseAddress parses the given address string based on the address type
4627
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4628
// and opaque addresses.
4629
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4630
        switch addrType {
×
4631
        case addressTypeIPv4:
×
4632
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4633
                if err != nil {
×
4634
                        return nil, err
×
4635
                }
×
4636

4637
                tcp.IP = tcp.IP.To4()
×
4638

×
4639
                return tcp, nil
×
4640

4641
        case addressTypeIPv6:
×
4642
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4643
                if err != nil {
×
4644
                        return nil, err
×
4645
                }
×
4646

4647
                return tcp, nil
×
4648

4649
        case addressTypeTorV3, addressTypeTorV2:
×
4650
                service, portStr, err := net.SplitHostPort(address)
×
4651
                if err != nil {
×
4652
                        return nil, fmt.Errorf("unable to split tor "+
×
4653
                                "address: %v", address)
×
4654
                }
×
4655

4656
                port, err := strconv.Atoi(portStr)
×
4657
                if err != nil {
×
4658
                        return nil, err
×
4659
                }
×
4660

4661
                return &tor.OnionAddr{
×
4662
                        OnionService: service,
×
4663
                        Port:         port,
×
4664
                }, nil
×
4665

4666
        case addressTypeOpaque:
×
4667
                opaque, err := hex.DecodeString(address)
×
4668
                if err != nil {
×
4669
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4670
                                "address: %v", address)
×
4671
                }
×
4672

4673
                return &lnwire.OpaqueAddrs{
×
4674
                        Payload: opaque,
×
4675
                }, nil
×
4676

4677
        default:
×
4678
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4679
        }
4680
}
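// Illustrative examples using hypothetical address strings:
//
//	a1, _ := parseAddress(addressTypeIPv4, "203.0.113.1:9735") // *net.TCPAddr
//	a2, _ := parseAddress(addressTypeOpaque, "00ff")           // *lnwire.OpaqueAddrs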
4681

4682
// batchNodeData holds all the related data for a batch of nodes.
4683
type batchNodeData struct {
4684
        // features is a map from a DB node ID to the feature bits for that
4685
        // node.
4686
        features map[int64][]int
4687

4688
        // addresses is a map from a DB node ID to the node's addresses.
4689
        addresses map[int64][]nodeAddress
4690

4691
        // extraFields is a map from a DB node ID to the extra signed fields
4692
        // for that node.
4693
        extraFields map[int64]map[uint64][]byte
4694
}
4695

4696
// nodeAddress holds the address type, position and address string for a
4697
// node. This is used to batch the fetching of node addresses.
4698
type nodeAddress struct {
4699
        addrType dbAddressType
4700
        position int32
4701
        address  string
4702
}
4703

4704
// batchLoadNodeData loads all related data for a batch of node IDs using the
4705
// provided SQLQueries interface. It returns a batchNodeData instance containing
4706
// the node features, addresses and extra signed fields.
4707
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4708
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4709

×
4710
        // Batch load the node features.
×
4711
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4712
        if err != nil {
×
4713
                return nil, fmt.Errorf("unable to batch load node "+
×
4714
                        "features: %w", err)
×
4715
        }
×
4716

4717
        // Batch load the node addresses.
4718
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4719
        if err != nil {
×
4720
                return nil, fmt.Errorf("unable to batch load node "+
×
4721
                        "addresses: %w", err)
×
4722
        }
×
4723

4724
        // Batch load the node extra signed fields.
4725
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4726
        if err != nil {
×
4727
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4728
                        "signed fields: %w", err)
×
4729
        }
×
4730

4731
        return &batchNodeData{
×
4732
                features:    features,
×
4733
                addresses:   addrs,
×
4734
                extraFields: extraTypes,
×
4735
        }, nil
×
4736
}
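
// exampleBuildNodes is a hypothetical sketch (not part of lnd) showing how
// batchLoadNodeData is meant to be combined with buildNodeWithBatchData:
// collect the DB IDs for a set of nodes, load the shared feature, address and
// extra-signed-field maps with a handful of batched queries, and then build
// each node from that cached data.
func exampleBuildNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, dbNodes []sqlc.GraphNode) ([]*models.LightningNode,
        error) {

        ids := make([]int64, len(dbNodes))
        for i, n := range dbNodes {
                ids[i] = n.ID
        }

        // One batched round trip per related table instead of one query per
        // node.
        batchData, err := batchLoadNodeData(ctx, cfg, db, ids)
        if err != nil {
                return nil, err
        }

        nodes := make([]*models.LightningNode, 0, len(dbNodes))
        for _, dbNode := range dbNodes {
                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return nil, err
                }

                nodes = append(nodes, node)
        }

        return nodes, nil
}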
4737

4738
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4739
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4740
func batchLoadNodeFeaturesHelper(ctx context.Context,
4741
        cfg *sqldb.QueryConfig, db SQLQueries,
4742
        nodeIDs []int64) (map[int64][]int, error) {
×
4743

×
4744
        features := make(map[int64][]int)
×
4745

×
4746
        return features, sqldb.ExecuteBatchQuery(
×
4747
                ctx, cfg, nodeIDs,
×
4748
                func(id int64) int64 {
×
4749
                        return id
×
4750
                },
×
4751
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4752
                        error) {
×
4753

×
4754
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4755
                },
×
4756
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4757
                        features[feature.NodeID] = append(
×
4758
                                features[feature.NodeID],
×
4759
                                int(feature.FeatureBit),
×
4760
                        )
×
4761

×
4762
                        return nil
×
4763
                },
×
4764
        )
4765
}
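
// Note that, as the call sites in this file suggest, sqldb.ExecuteBatchQuery
// is handed the full ID slice, a key-extraction function, the batched query
// itself and a per-row callback. The helpers here use the callback only to
// fold the returned rows into the map they return, so by the time
// ExecuteBatchQuery returns, the map handed back alongside its error is
// fully populated.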
4766

4767
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
4768
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4769
// node ID to a slice of nodeAddress structs.
4770
func batchLoadNodeAddressesHelper(ctx context.Context,
4771
        cfg *sqldb.QueryConfig, db SQLQueries,
4772
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4773

×
4774
        addrs := make(map[int64][]nodeAddress)
×
4775

×
4776
        return addrs, sqldb.ExecuteBatchQuery(
×
4777
                ctx, cfg, nodeIDs,
×
4778
                func(id int64) int64 {
×
4779
                        return id
×
4780
                },
×
4781
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4782
                        error) {
×
4783

×
4784
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4785
                },
×
4786
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4787
                        addrs[addr.NodeID] = append(
×
4788
                                addrs[addr.NodeID], nodeAddress{
×
4789
                                        addrType: dbAddressType(addr.Type),
×
4790
                                        position: addr.Position,
×
4791
                                        address:  addr.Address,
×
4792
                                },
×
4793
                        )
×
4794

×
4795
                        return nil
×
4796
                },
×
4797
        )
4798
}
4799

4800
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4801
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
4802
// query.
4803
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4804
        cfg *sqldb.QueryConfig, db SQLQueries,
4805
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4806

×
4807
        extraFields := make(map[int64]map[uint64][]byte)
×
4808

×
4809
        callback := func(ctx context.Context,
×
4810
                field sqlc.GraphNodeExtraType) error {
×
4811

×
4812
                if extraFields[field.NodeID] == nil {
×
4813
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4814
                }
×
4815
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4816

×
4817
                return nil
×
4818
        }
4819

4820
        return extraFields, sqldb.ExecuteBatchQuery(
×
4821
                ctx, cfg, nodeIDs,
×
4822
                func(id int64) int64 {
×
4823
                        return id
×
4824
                },
×
4825
                func(ctx context.Context, ids []int64) (
4826
                        []sqlc.GraphNodeExtraType, error) {
×
4827

×
4828
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4829
                },
×
4830
                callback,
4831
        )
4832
}
4833

4834
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
4835
// from the provided sqlc.GraphChannelPolicy records and the
4836
// provided batchChannelData.
4837
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4838
        channelID uint64, node1, node2 route.Vertex,
4839
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
4840
        *models.ChannelEdgePolicy, error) {
×
4841

×
4842
        pol1, err := buildChanPolicyWithBatchData(
×
4843
                dbPol1, channelID, node2, batchData,
×
4844
        )
×
4845
        if err != nil {
×
4846
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4847
        }
×
4848

4849
        pol2, err := buildChanPolicyWithBatchData(
×
4850
                dbPol2, channelID, node1, batchData,
×
4851
        )
×
4852
        if err != nil {
×
4853
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4854
        }
×
4855

4856
        return pol1, pol2, nil
×
4857
}
4858

4859
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
4860
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
4861
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
4862
        channelID uint64, toNode route.Vertex,
4863
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
4864

×
4865
        if dbPol == nil {
×
4866
                return nil, nil
×
4867
        }
×
4868

4869
        var dbPol1Extras map[uint64][]byte
×
4870
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
4871
                dbPol1Extras = extras
×
4872
        } else {
×
4873
                dbPol1Extras = make(map[uint64][]byte)
×
4874
        }
×
4875

4876
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
4877
}
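
// Note that a nil dbPol intentionally yields a nil policy above: a channel
// for which only one channel_update has been processed ends up with a nil
// policy on the missing side, which is why downstream callers nil-check both
// policies before use.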
4878

4879
// batchChannelData holds all the related data for a batch of channels.
4880
type batchChannelData struct {
4881
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
4882
        chanfeatures map[int64][]int
4883

4884
        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
4885
        // extra signed field bytes.
4886
        chanExtraTypes map[int64]map[uint64][]byte
4887

4888
        // policyExtras is a map from DB channel policy ID to a map of TLV type
4889
        // to extra signed field bytes.
4890
        policyExtras map[int64]map[uint64][]byte
4891
}
4892

4893
// batchLoadChannelData loads all related data for batches of channels and
4894
// policies.
4895
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
4896
        db SQLQueries, channelIDs []int64,
4897
        policyIDs []int64) (*batchChannelData, error) {
×
4898

×
4899
        batchData := &batchChannelData{
×
4900
                chanfeatures:   make(map[int64][]int),
×
4901
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
4902
                policyExtras:   make(map[int64]map[uint64][]byte),
×
4903
        }
×
4904

×
4905
        // Batch load the channel features and extras.
×
4906
        var err error
×
4907
        if len(channelIDs) > 0 {
×
4908
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
4909
                        ctx, cfg, db, channelIDs,
×
4910
                )
×
4911
                if err != nil {
×
4912
                        return nil, fmt.Errorf("unable to batch load "+
×
4913
                                "channel features: %w", err)
×
4914
                }
×
4915

4916
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
4917
                        ctx, cfg, db, channelIDs,
×
4918
                )
×
4919
                if err != nil {
×
4920
                        return nil, fmt.Errorf("unable to batch load "+
×
4921
                                "channel extras: %w", err)
×
4922
                }
×
4923
        }
4924

4925
        if len(policyIDs) > 0 {
×
4926
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
4927
                        ctx, cfg, db, policyIDs,
×
4928
                )
×
4929
                if err != nil {
×
4930
                        return nil, fmt.Errorf("unable to batch load "+
×
4931
                                "policy extras: %w", err)
×
4932
                }
×
4933
                batchData.policyExtras = policyExtras
×
4934
        }
4935

4936
        return batchData, nil
×
4937
}
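
// Note that the channel feature and extra-type queries are only issued when
// channelIDs is non-empty, and the policy extra query only when policyIDs is
// non-empty, so callers such as batchBuildChannelInfo may pass a nil policy
// ID slice to skip policy data entirely. The empty maps created up front are
// kept either way, so lookups on the returned batchChannelData always operate
// on initialised maps.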
4938

4939
// batchLoadChannelFeaturesHelper loads channel features for a batch of
4940
// channel IDs using ExecuteBatchQuery wrapper around the
4941
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
4942
// slice of feature bits.
4943
func batchLoadChannelFeaturesHelper(ctx context.Context,
4944
        cfg *sqldb.QueryConfig, db SQLQueries,
4945
        channelIDs []int64) (map[int64][]int, error) {
×
4946

×
4947
        features := make(map[int64][]int)
×
4948

×
4949
        return features, sqldb.ExecuteBatchQuery(
×
4950
                ctx, cfg, channelIDs,
×
4951
                func(id int64) int64 {
×
4952
                        return id
×
4953
                },
×
4954
                func(ctx context.Context,
4955
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
4956

×
4957
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
4958
                },
×
4959
                func(ctx context.Context,
4960
                        feature sqlc.GraphChannelFeature) error {
×
4961

×
4962
                        features[feature.ChannelID] = append(
×
4963
                                features[feature.ChannelID],
×
4964
                                int(feature.FeatureBit),
×
4965
                        )
×
4966

×
4967
                        return nil
×
4968
                },
×
4969
        )
4970
}
4971

4972
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
4973
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
4974
// query. It returns a map from DB channel ID to a map of TLV type to extra
4975
// signed field bytes.
4976
func batchLoadChannelExtrasHelper(ctx context.Context,
4977
        cfg *sqldb.QueryConfig, db SQLQueries,
4978
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4979

×
4980
        extras := make(map[int64]map[uint64][]byte)
×
4981

×
4982
        cb := func(ctx context.Context,
×
4983
                extra sqlc.GraphChannelExtraType) error {
×
4984

×
4985
                if extras[extra.ChannelID] == nil {
×
4986
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
4987
                }
×
4988
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
4989

×
4990
                return nil
×
4991
        }
4992

4993
        return extras, sqldb.ExecuteBatchQuery(
×
4994
                ctx, cfg, channelIDs,
×
4995
                func(id int64) int64 {
×
4996
                        return id
×
4997
                },
×
4998
                func(ctx context.Context,
4999
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5000

×
5001
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5002
                }, cb,
×
5003
        )
5004
}
5005

5006
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5007
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5008
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5009
// a map of TLV type to extra signed field bytes.
5010
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5011
        cfg *sqldb.QueryConfig, db SQLQueries,
5012
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5013

×
5014
        extras := make(map[int64]map[uint64][]byte)
×
5015

×
5016
        return extras, sqldb.ExecuteBatchQuery(
×
5017
                ctx, cfg, policyIDs,
×
5018
                func(id int64) int64 {
×
5019
                        return id
×
5020
                },
×
5021
                func(ctx context.Context, ids []int64) (
5022
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5023

×
5024
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5025
                },
×
5026
                func(ctx context.Context,
5027
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5028

×
5029
                        if extras[row.PolicyID] == nil {
×
5030
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5031
                        }
×
5032
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5033

×
5034
                        return nil
×
5035
                },
5036
        )
5037
}
5038

5039
// forEachNodePaginated executes a paginated query to process each node in the
5040
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5041
// and applies the provided processNode function to each node.
5042
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5043
        db SQLQueries, protocol ProtocolVersion,
5044
        processNode func(context.Context, int64,
5045
                *models.LightningNode) error) error {
×
5046

×
5047
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5048
                limit int32) ([]sqlc.GraphNode, error) {
×
5049

×
5050
                return db.ListNodesPaginated(
×
5051
                        ctx, sqlc.ListNodesPaginatedParams{
×
5052
                                Version: int16(protocol),
×
5053
                                ID:      lastID,
×
5054
                                Limit:   limit,
×
5055
                        },
×
5056
                )
×
5057
        }
×
5058

5059
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5060
                return node.ID
×
5061
        }
×
5062

5063
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5064
                return node.ID, nil
×
5065
        }
×
5066

5067
        batchQueryFunc := func(ctx context.Context,
×
5068
                nodeIDs []int64) (*batchNodeData, error) {
×
5069

×
5070
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5071
        }
×
5072

5073
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5074
                batchData *batchNodeData) error {
×
5075

×
5076
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5077
                if err != nil {
×
5078
                        return fmt.Errorf("unable to build "+
×
5079
                                "node(id=%d): %w", dbNode.ID, err)
×
5080
                }
×
5081

5082
                return processNode(ctx, dbNode.ID, node)
×
5083
        }
5084

5085
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5086
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5087
                collectFunc, batchQueryFunc, processItem,
×
5088
        )
×
5089
}
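
// exampleCountNodes is a hypothetical sketch (not part of lnd) of a
// forEachNodePaginated caller: nodes are streamed page by page rather than
// loaded all at once, and the callback receives the DB node ID alongside the
// assembled node, so aggregates can be computed without holding the whole
// graph in memory.
func exampleCountNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries) (int, error) {

        var numNodes int
        err := forEachNodePaginated(
                ctx, cfg, db, ProtocolV1,
                func(_ context.Context, _ int64,
                        _ *models.LightningNode) error {

                        numNodes++

                        return nil
                },
        )

        return numNodes, err
}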
5090

5091
// forEachChannelWithPolicies executes a paginated query to process each channel
5092
// with policies in the graph.
5093
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5094
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5095
                *models.ChannelEdgePolicy,
5096
                *models.ChannelEdgePolicy) error) error {
×
5097

×
5098
        type channelBatchIDs struct {
×
5099
                channelID int64
×
5100
                policyIDs []int64
×
5101
        }
×
5102

×
5103
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5104
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5105
                error) {
×
5106

×
5107
                return db.ListChannelsWithPoliciesPaginated(
×
5108
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5109
                                Version: int16(ProtocolV1),
×
5110
                                ID:      lastID,
×
5111
                                Limit:   limit,
×
5112
                        },
×
5113
                )
×
5114
        }
×
5115

5116
        extractPageCursor := func(
×
5117
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5118

×
5119
                return row.GraphChannel.ID
×
5120
        }
×
5121

5122
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5123
                channelBatchIDs, error) {
×
5124

×
5125
                ids := channelBatchIDs{
×
5126
                        channelID: row.GraphChannel.ID,
×
5127
                }
×
5128

×
5129
                // Extract policy IDs from the row.
×
5130
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5131
                if err != nil {
×
5132
                        return ids, err
×
5133
                }
×
5134

5135
                if dbPol1 != nil {
×
5136
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5137
                }
×
5138
                if dbPol2 != nil {
×
5139
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5140
                }
×
5141

5142
                return ids, nil
×
5143
        }
5144

5145
        batchDataFunc := func(ctx context.Context,
×
5146
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5147

×
5148
                // Separate channel IDs from policy IDs.
×
5149
                var (
×
5150
                        channelIDs = make([]int64, len(allIDs))
×
5151
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5152
                )
×
5153

×
5154
                for i, ids := range allIDs {
×
5155
                        channelIDs[i] = ids.channelID
×
5156
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5157
                }
×
5158

5159
                return batchLoadChannelData(
×
5160
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5161
                )
×
5162
        }
5163

5164
        processItem := func(ctx context.Context,
×
5165
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5166
                batchData *batchChannelData) error {
×
5167

×
5168
                node1, node2, err := buildNodeVertices(
×
5169
                        row.Node1Pubkey, row.Node2Pubkey,
×
5170
                )
×
5171
                if err != nil {
×
5172
                        return err
×
5173
                }
×
5174

5175
                edge, err := buildEdgeInfoWithBatchData(
×
5176
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5177
                        batchData,
×
5178
                )
×
5179
                if err != nil {
×
5180
                        return fmt.Errorf("unable to build channel info: %w",
×
5181
                                err)
×
5182
                }
×
5183

5184
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5185
                if err != nil {
×
5186
                        return err
×
5187
                }
×
5188

5189
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5190
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5191
                )
×
5192
                if err != nil {
×
5193
                        return err
×
5194
                }
×
5195

5196
                return processChannel(edge, p1, p2)
×
5197
        }
5198

5199
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5200
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5201
                collectFunc, batchDataFunc, processItem,
×
5202
        )
×
5203
}
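
// exampleCountChannelsWithBothPolicies is a hypothetical sketch (not part of
// lnd) of a forEachChannelWithPolicies caller. Either policy argument may be
// nil when the corresponding update has not been seen, so the callback
// nil-checks both before counting the channel as fully announced.
func exampleCountChannelsWithBothPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig) (int, error) {

        var numComplete int
        err := forEachChannelWithPolicies(
                ctx, db, cfg,
                func(_ *models.ChannelEdgeInfo, p1,
                        p2 *models.ChannelEdgePolicy) error {

                        if p1 != nil && p2 != nil {
                                numComplete++
                        }

                        return nil
                },
        )

        return numComplete, err
}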
5204

5205
// buildDirectedChannel builds a DirectedChannel instance from the provided
5206
// data.
5207
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5208
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5209
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5210
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5211

×
5212
        node1, node2, err := buildNodeVertices(
×
5213
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5214
        )
×
5215
        if err != nil {
×
5216
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5217
        }
×
5218

5219
        edge, err := buildEdgeInfoWithBatchData(
×
5220
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5221
        )
×
5222
        if err != nil {
×
5223
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5224
        }
×
5225

5226
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5227
        if err != nil {
×
5228
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5229
                        err)
×
5230
        }
×
5231

5232
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5233
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5234
                channelBatchData,
×
5235
        )
×
5236
        if err != nil {
×
5237
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5238
                        err)
×
5239
        }
×
5240

5241
        // Determine outgoing and incoming policy for this specific node.
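        // Policy 1 always points towards node 2 and policy 2 towards node 1,
        // so the default assignment below is correct when nodeID refers to
        // node 1; otherwise the pair is swapped.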
5242
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5243
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5244
        outPolicy, inPolicy := p1, p2
×
5245
        if (p1 != nil && p1ToNode == nodeID) ||
×
5246
                (p2 != nil && p2ToNode != nodeID) {
×
5247

×
5248
                outPolicy, inPolicy = p2, p1
×
5249
        }
×
5250

5251
        // Build cached policy.
5252
        var cachedInPolicy *models.CachedEdgePolicy
×
5253
        if inPolicy != nil {
×
5254
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5255
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5256
                cachedInPolicy.ToNodeFeatures = features
×
5257
        }
×
5258

5259
        // Extract inbound fee.
5260
        var inboundFee lnwire.Fee
×
5261
        if outPolicy != nil {
×
5262
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5263
                        inboundFee = fee
×
5264
                })
×
5265
        }
5266

5267
        // Build directed channel.
5268
        directedChannel := &DirectedChannel{
×
5269
                ChannelID:    edge.ChannelID,
×
5270
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5271
                OtherNode:    edge.NodeKey2Bytes,
×
5272
                Capacity:     edge.Capacity,
×
5273
                OutPolicySet: outPolicy != nil,
×
5274
                InPolicy:     cachedInPolicy,
×
5275
                InboundFee:   inboundFee,
×
5276
        }
×
5277

×
5278
        if nodePub == edge.NodeKey2Bytes {
×
5279
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5280
        }
×
5281

5282
        return directedChannel, nil
×
5283
}
5284

5285
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5286
// provided rows. It uses batch loading for channels, policies, and nodes.
5287
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5288
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5289

×
5290
        var (
×
5291
                channelIDs = make([]int64, len(rows))
×
5292
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5293
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5294

×
5295
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5296
                nodeIDSet = make(map[int64]bool)
×
5297

×
5298
                // edges will hold the final channel edges built from the rows.
×
5299
                edges = make([]ChannelEdge, 0, len(rows))
×
5300
        )
×
5301

×
5302
        // Collect all IDs needed for batch loading.
×
5303
        for i, row := range rows {
×
5304
                channelIDs[i] = row.Channel().ID
×
5305

×
5306
                // Collect the policy IDs.
×
5307
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5308
                if err != nil {
×
5309
                        return nil, fmt.Errorf("unable to extract channel "+
×
5310
                                "policies: %w", err)
×
5311
                }
×
5312
                if dbPol1 != nil {
×
5313
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5314
                }
×
5315
                if dbPol2 != nil {
×
5316
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5317
                }
×
5318

5319
                var (
×
5320
                        node1ID = row.Node1().ID
×
5321
                        node2ID = row.Node2().ID
×
5322
                )
×
5323

×
5324
                // Collect unique node IDs.
×
5325
                if !nodeIDSet[node1ID] {
×
5326
                        nodeIDs = append(nodeIDs, node1ID)
×
5327
                        nodeIDSet[node1ID] = true
×
5328
                }
×
5329

5330
                if !nodeIDSet[node2ID] {
×
5331
                        nodeIDs = append(nodeIDs, node2ID)
×
5332
                        nodeIDSet[node2ID] = true
×
5333
                }
×
5334
        }
5335

5336
        // Batch the data for all the channels and policies.
5337
        channelBatchData, err := batchLoadChannelData(
×
5338
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5339
        )
×
5340
        if err != nil {
×
5341
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5342
                        "policy data: %w", err)
×
5343
        }
×
5344

5345
        // Batch the data for all the nodes.
5346
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5347
        if err != nil {
×
5348
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5349
                        err)
×
5350
        }
×
5351

5352
        // Build all channel edges using batch data.
5353
        for _, row := range rows {
×
5354
                // Build nodes using batch data.
×
5355
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5356
                if err != nil {
×
5357
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5358
                }
×
5359

5360
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5361
                if err != nil {
×
5362
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5363
                }
×
5364

5365
                // Build channel info using batch data.
5366
                channel, err := buildEdgeInfoWithBatchData(
×
5367
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5368
                        node2.PubKeyBytes, channelBatchData,
×
5369
                )
×
5370
                if err != nil {
×
5371
                        return nil, fmt.Errorf("unable to build channel "+
×
5372
                                "info: %w", err)
×
5373
                }
×
5374

5375
                // Extract and build policies using batch data.
5376
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5377
                if err != nil {
×
5378
                        return nil, fmt.Errorf("unable to extract channel "+
×
5379
                                "policies: %w", err)
×
5380
                }
×
5381

5382
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5383
                        dbPol1, dbPol2, channel.ChannelID,
×
5384
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5385
                )
×
5386
                if err != nil {
×
5387
                        return nil, fmt.Errorf("unable to build channel "+
×
5388
                                "policies: %w", err)
×
5389
                }
×
5390

5391
                edges = append(edges, ChannelEdge{
×
5392
                        Info:    channel,
×
5393
                        Policy1: p1,
×
5394
                        Policy2: p2,
×
5395
                        Node1:   node1,
×
5396
                        Node2:   node2,
×
5397
                })
×
5398
        }
5399

5400
        return edges, nil
×
5401
}
5402

5403
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5404
// instances from the provided rows using batch loading for channel data.
5405
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5406
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5407
        []*models.ChannelEdgeInfo, []int64, error) {
×
5408

×
5409
        if len(rows) == 0 {
×
5410
                return nil, nil, nil
×
5411
        }
×
5412

5413
        // Collect all the channel IDs needed for batch loading.
5414
        channelIDs := make([]int64, len(rows))
×
5415
        for i, row := range rows {
×
5416
                channelIDs[i] = row.Channel().ID
×
5417
        }
×
5418

5419
        // Batch load the channel data.
5420
        channelBatchData, err := batchLoadChannelData(
×
5421
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5422
        )
×
5423
        if err != nil {
×
5424
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5425
                        "data: %w", err)
×
5426
        }
×
5427

5428
        // Build all channel edges using batch data.
5429
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5430
        for _, row := range rows {
×
5431
                node1, node2, err := buildNodeVertices(
×
5432
                        row.Node1Pub(), row.Node2Pub(),
×
5433
                )
×
5434
                if err != nil {
×
5435
                        return nil, nil, err
×
5436
                }
×
5437

5438
                // Build channel info using batch data.
5439
                info, err := buildEdgeInfoWithBatchData(
×
5440
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5441
                        channelBatchData,
×
5442
                )
×
5443
                if err != nil {
×
5444
                        return nil, nil, err
×
5445
                }
×
5446

5447
                edges = append(edges, info)
×
5448
        }
5449

5450
        return edges, channelIDs, nil
×
5451
}
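
// Note that, unlike batchBuildChannelEdges above, batchBuildChannelInfo skips
// policy loading by passing a nil policy ID slice to batchLoadChannelData,
// and it also returns the collected DB channel IDs so that the caller can run
// follow-up queries keyed on the same IDs.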
5452

5453
// handleZombieMarking is a helper function that handles the logic of
5454
// marking a channel as a zombie in the database. It takes into account whether
5455
// we are in strict zombie pruning mode, and adjusts the node public keys
5456
// accordingly based on the last update timestamps of the channel policies.
5457
func handleZombieMarking(ctx context.Context, db SQLQueries,
5458
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5459
        strictZombiePruning bool, scid uint64) error {
×
5460

×
5461
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5462

×
5463
        if strictZombiePruning {
×
5464
                var e1UpdateTime, e2UpdateTime *time.Time
×
5465
                if row.Policy1LastUpdate.Valid {
×
5466
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5467
                        e1UpdateTime = &e1Time
×
5468
                }
×
5469
                if row.Policy2LastUpdate.Valid {
×
5470
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5471
                        e2UpdateTime = &e2Time
×
5472
                }
×
5473

5474
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5475
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5476
                        e2UpdateTime,
×
5477
                )
×
5478
        }
5479

5480
        return db.UpsertZombieChannel(
×
5481
                ctx, sqlc.UpsertZombieChannelParams{
×
5482
                        Version:  int16(ProtocolV1),
×
5483
                        Scid:     channelIDToBytes(scid),
×
5484
                        NodeKey1: nodeKey1[:],
×
5485
                        NodeKey2: nodeKey2[:],
×
5486
                },
×
5487
        )
×
5488
}
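
// Note that with strict zombie pruning disabled, the channel's original node
// keys are stored with the zombie entry as-is; in strict mode the
// per-direction last-update timestamps are first passed through
// makeZombiePubkeys to decide which node keys to record.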