lightningnetwork / lnd / 17994533969

25 Sep 2025 01:52AM UTC coverage: 66.626% (-0.03%) from 66.651%

Merge 2b459bef3 into df46d1862
Pull Request #10128: multi: update ChanUpdatesInHorizon and NodeUpdatesInHorizon to return iterators (iter.Seq[T])

308 of 622 new or added lines in 7 files covered. (49.52%)

75 existing lines in 15 files now uncovered.

136757 of 205262 relevant lines covered (66.63%)

21384.62 hits per line

Source File

/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "iter"
11
        "maps"
12
        "math"
13
        "net"
14
        "slices"
15
        "strconv"
16
        "sync"
17
        "time"
18

19
        "github.com/btcsuite/btcd/btcec/v2"
20
        "github.com/btcsuite/btcd/btcutil"
21
        "github.com/btcsuite/btcd/chaincfg/chainhash"
22
        "github.com/btcsuite/btcd/wire"
23
        "github.com/lightningnetwork/lnd/aliasmgr"
24
        "github.com/lightningnetwork/lnd/batch"
25
        "github.com/lightningnetwork/lnd/fn/v2"
26
        "github.com/lightningnetwork/lnd/graph/db/models"
27
        "github.com/lightningnetwork/lnd/lnwire"
28
        "github.com/lightningnetwork/lnd/routing/route"
29
        "github.com/lightningnetwork/lnd/sqldb"
30
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
31
        "github.com/lightningnetwork/lnd/tlv"
32
        "github.com/lightningnetwork/lnd/tor"
33
)
34

35
// ProtocolVersion is an enum that defines the gossip protocol version of a
36
// message.
37
type ProtocolVersion uint8
38

39
const (
40
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
41
        ProtocolV1 ProtocolVersion = 1
42
)
43

44
// String returns a string representation of the protocol version.
45
func (v ProtocolVersion) String() string {
×
46
        return fmt.Sprintf("V%d", v)
×
47
}
×
48

49
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
50
// execute queries against the SQL graph tables.
51
//
52
//nolint:ll,interfacebloat
53
type SQLQueries interface {
54
        /*
55
                Node queries.
56
        */
57
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
58
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
59
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
60
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
61
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
62
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
63
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
64
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
65
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
66
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
67
        DeleteNode(ctx context.Context, id int64) error
68

69
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
70
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
71
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
72
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
73

74
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
75
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
76
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
82
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
83
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
84

85
        /*
86
                Source node queries.
87
        */
88
        AddSourceNode(ctx context.Context, nodeID int64) error
89
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
90

91
        /*
92
                Channel queries.
93
        */
94
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
95
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
96
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
97
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
98
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
99
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
100
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
101
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
102
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
103
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
104
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
105
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
106
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
107
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
108
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
109
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
110
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
111
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
112
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
113
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
114
        DeleteChannels(ctx context.Context, ids []int64) error
115

116
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
117
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
118
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
119
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
120

121
        /*
122
                Channel Policy table queries.
123
        */
124
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
125
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
126
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
127

128
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
129
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
130
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
131

132
        /*
133
                Zombie index queries.
134
        */
135
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
136
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
137
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
138
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
139
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
140
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
141

142
        /*
143
                Prune log table queries.
144
        */
145
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
146
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
147
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
148
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
149
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
150

151
        /*
152
                Closed SCID table queries.
153
        */
154
        InsertClosedChannel(ctx context.Context, scid []byte) error
155
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
156
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
157

158
        /*
159
                Migration specific queries.
160

161
                NOTE: these should not be used in code other than migrations.
162
                Once sqldbv2 is in place, these can be removed from this struct
163
                as then migrations will have their own dedicated queries
164
                structs.
165
        */
166
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
167
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
168
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
169
}
170

171
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
172
// database operations.
173
type BatchedSQLQueries interface {
174
        SQLQueries
175
        sqldb.BatchedTx[SQLQueries]
176
}
177

178
// SQLStore is an implementation of the V1Store interface that uses a SQL
179
// database as the backend.
180
type SQLStore struct {
181
        cfg *SQLStoreConfig
182
        db  BatchedSQLQueries
183

184
        // cacheMu guards all caches (rejectCache and chanCache). If
185
        // this mutex is acquired at the same time as the DB mutex, then
186
        // the cacheMu MUST be acquired first to prevent deadlock.
187
        cacheMu     sync.RWMutex
188
        rejectCache *rejectCache
189
        chanCache   *channelCache
190

191
        chanScheduler batch.Scheduler[SQLQueries]
192
        nodeScheduler batch.Scheduler[SQLQueries]
193

194
        srcNodes  map[ProtocolVersion]*srcNodeInfo
195
        srcNodeMu sync.Mutex
196
}
197

198
// A compile-time assertion to ensure that SQLStore implements the V1Store
199
// interface.
200
var _ V1Store = (*SQLStore)(nil)
201

202
// SQLStoreConfig holds the configuration for the SQLStore.
203
type SQLStoreConfig struct {
204
        // ChainHash is the genesis hash for the chain that all the gossip
205
        // messages in this store are aimed at.
206
        ChainHash chainhash.Hash
207

208
        // QueryConfig holds configuration values for SQL queries.
209
        QueryCfg *sqldb.QueryConfig
210
}
211

212
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
213
// storage backend.
214
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
215
        options ...StoreOptionModifier) (*SQLStore, error) {
×
216

×
217
        opts := DefaultOptions()
×
218
        for _, o := range options {
×
219
                o(opts)
×
220
        }
×
221

222
        if opts.NoMigration {
×
223
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
224
                        "supported for SQL stores")
×
225
        }
×
226

227
        s := &SQLStore{
×
228
                cfg:         cfg,
×
229
                db:          db,
×
230
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
231
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
232
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
233
        }
×
234

×
235
        s.chanScheduler = batch.NewTimeScheduler(
×
236
                db, &s.cacheMu, opts.BatchCommitInterval,
×
237
        )
×
238
        s.nodeScheduler = batch.NewTimeScheduler(
×
239
                db, nil, opts.BatchCommitInterval,
×
240
        )
×
241

×
242
        return s, nil
×
243
}
244
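
// exampleOpenStore is an illustrative sketch only (not part of the original
// file): it shows how a caller might wire up a SQLStore. The chainHash,
// queryCfg and db arguments are assumed to be prepared elsewhere, with db
// being any BatchedSQLQueries implementation.
func exampleOpenStore(chainHash chainhash.Hash, queryCfg *sqldb.QueryConfig,
        db BatchedSQLQueries) (*SQLStore, error) {

        return NewSQLStore(&SQLStoreConfig{
                ChainHash: chainHash,
                QueryCfg:  queryCfg,
        }, db)
}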

245
// AddNode adds a vertex/node to the graph database. If the node is not
246
// already in the database, this will add a new, unconnected node to the
247
// graph. If it is already present, this will update that node's
248
// information.
249
//
250
// NOTE: part of the V1Store interface.
251
func (s *SQLStore) AddNode(ctx context.Context,
252
        node *models.Node, opts ...batch.SchedulerOption) error {
×
253

×
254
        r := &batch.Request[SQLQueries]{
×
255
                Opts: batch.NewSchedulerOptions(opts...),
×
256
                Do: func(queries SQLQueries) error {
×
257
                        _, err := upsertNode(ctx, queries, node)
×
258
                        return err
×
259
                },
×
260
        }
261

262
        return s.nodeScheduler.Execute(ctx, r)
×
263
}
264

265
// FetchNode attempts to look up a target node by its identity public
266
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
267
// returned.
268
//
269
// NOTE: part of the V1Store interface.
270
func (s *SQLStore) FetchNode(ctx context.Context,
271
        pubKey route.Vertex) (*models.Node, error) {
×
272

×
273
        var node *models.Node
×
274
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
275
                var err error
×
276
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
277

×
278
                return err
×
279
        }, sqldb.NoOpReset)
×
280
        if err != nil {
×
281
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
282
        }
×
283

284
        return node, nil
×
285
}
286

287
// HasNode determines if the graph has a vertex identified by the
288
// target node identity public key. If the node exists in the database, a
289
// timestamp of when the data for the node was last updated is returned along
290
// with a true boolean. Otherwise, an empty time.Time is returned with a false
291
// boolean.
292
//
293
// NOTE: part of the V1Store interface.
294
func (s *SQLStore) HasNode(ctx context.Context,
295
        pubKey [33]byte) (time.Time, bool, error) {
×
296

×
297
        var (
×
298
                exists     bool
×
299
                lastUpdate time.Time
×
300
        )
×
301
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
302
                dbNode, err := db.GetNodeByPubKey(
×
303
                        ctx, sqlc.GetNodeByPubKeyParams{
×
304
                                Version: int16(ProtocolV1),
×
305
                                PubKey:  pubKey[:],
×
306
                        },
×
307
                )
×
308
                if errors.Is(err, sql.ErrNoRows) {
×
309
                        return nil
×
310
                } else if err != nil {
×
311
                        return fmt.Errorf("unable to fetch node: %w", err)
×
312
                }
×
313

314
                exists = true
×
315

×
316
                if dbNode.LastUpdate.Valid {
×
317
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
318
                }
×
319

320
                return nil
×
321
        }, sqldb.NoOpReset)
322
        if err != nil {
×
323
                return time.Time{}, false,
×
324
                        fmt.Errorf("unable to fetch node: %w", err)
×
325
        }
×
326

327
        return lastUpdate, exists, nil
×
328
}
329

330
// AddrsForNode returns all known addresses for the target node public key
331
// that the graph DB is aware of. The returned boolean indicates if the
332
// given node is known to the graph DB or not.
333
//
334
// NOTE: part of the V1Store interface.
335
func (s *SQLStore) AddrsForNode(ctx context.Context,
336
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
337

×
338
        var (
×
339
                addresses []net.Addr
×
340
                known     bool
×
341
        )
×
342
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
343
                // First, check if the node exists and get its DB ID if it
×
344
                // does.
×
345
                dbID, err := db.GetNodeIDByPubKey(
×
346
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
347
                                Version: int16(ProtocolV1),
×
348
                                PubKey:  nodePub.SerializeCompressed(),
×
349
                        },
×
350
                )
×
351
                if errors.Is(err, sql.ErrNoRows) {
×
352
                        return nil
×
353
                }
×
354

355
                known = true
×
356

×
357
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
358
                if err != nil {
×
359
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
360
                                err)
×
361
                }
×
362

363
                return nil
×
364
        }, sqldb.NoOpReset)
365
        if err != nil {
×
366
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
367
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
368
        }
×
369

370
        return known, addresses, nil
×
371
}
372

373
// DeleteNode starts a new database transaction to remove a vertex/node
374
// from the database according to the node's public key.
375
//
376
// NOTE: part of the V1Store interface.
377
func (s *SQLStore) DeleteNode(ctx context.Context,
378
        pubKey route.Vertex) error {
×
379

×
380
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
381
                res, err := db.DeleteNodeByPubKey(
×
382
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
383
                                Version: int16(ProtocolV1),
×
384
                                PubKey:  pubKey[:],
×
385
                        },
×
386
                )
×
387
                if err != nil {
×
388
                        return err
×
389
                }
×
390

391
                rows, err := res.RowsAffected()
×
392
                if err != nil {
×
393
                        return err
×
394
                }
×
395

396
                if rows == 0 {
×
397
                        return ErrGraphNodeNotFound
×
398
                } else if rows > 1 {
×
399
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
400
                }
×
401

402
                return err
×
403
        }, sqldb.NoOpReset)
404
        if err != nil {
×
405
                return fmt.Errorf("unable to delete node: %w", err)
×
406
        }
×
407

408
        return nil
×
409
}
410

411
// FetchNodeFeatures returns the features of the given node. If no features are
412
// known for the node, an empty feature vector is returned.
413
//
414
// NOTE: this is part of the graphdb.NodeTraverser interface.
415
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
416
        *lnwire.FeatureVector, error) {
×
417

×
418
        ctx := context.TODO()
×
419

×
420
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
421
}
×
422

423
// DisabledChannelIDs returns the channel ids of disabled channels.
424
// A channel is disabled when both of the associated ChannelEdgePolicies
425
// have their disabled bit on.
426
//
427
// NOTE: part of the V1Store interface.
428
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
429
        var (
×
430
                ctx     = context.TODO()
×
431
                chanIDs []uint64
×
432
        )
×
433
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
434
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
435
                if err != nil {
×
436
                        return fmt.Errorf("unable to fetch disabled "+
×
437
                                "channels: %w", err)
×
438
                }
×
439

440
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
441

×
442
                return nil
×
443
        }, sqldb.NoOpReset)
444
        if err != nil {
×
445
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
446
                        err)
×
447
        }
×
448

449
        return chanIDs, nil
×
450
}
451

452
// LookupAlias attempts to return the alias as advertised by the target node.
453
//
454
// NOTE: part of the V1Store interface.
455
func (s *SQLStore) LookupAlias(ctx context.Context,
456
        pub *btcec.PublicKey) (string, error) {
×
457

×
458
        var alias string
×
459
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
460
                dbNode, err := db.GetNodeByPubKey(
×
461
                        ctx, sqlc.GetNodeByPubKeyParams{
×
462
                                Version: int16(ProtocolV1),
×
463
                                PubKey:  pub.SerializeCompressed(),
×
464
                        },
×
465
                )
×
466
                if errors.Is(err, sql.ErrNoRows) {
×
467
                        return ErrNodeAliasNotFound
×
468
                } else if err != nil {
×
469
                        return fmt.Errorf("unable to fetch node: %w", err)
×
470
                }
×
471

472
                if !dbNode.Alias.Valid {
×
473
                        return ErrNodeAliasNotFound
×
474
                }
×
475

476
                alias = dbNode.Alias.String
×
477

×
478
                return nil
×
479
        }, sqldb.NoOpReset)
480
        if err != nil {
×
481
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
482
        }
×
483

484
        return alias, nil
×
485
}
486

487
// SourceNode returns the source node of the graph. The source node is treated
488
// as the center node within a star-graph. This method may be used to kick off
489
// a path finding algorithm in order to explore the reachability of another
490
// node based on the source node.
491
//
492
// NOTE: part of the V1Store interface.
493
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
494
        error) {
×
495

×
496
        var node *models.Node
×
497
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
498
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
499
                if err != nil {
×
500
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
501
                                err)
×
502
                }
×
503

504
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
505

×
506
                return err
×
507
        }, sqldb.NoOpReset)
508
        if err != nil {
×
509
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
510
        }
×
511

512
        return node, nil
×
513
}
514

515
// SetSourceNode sets the source node within the graph database. The source
516
// node is to be used as the center of a star-graph within path finding
517
// algorithms.
518
//
519
// NOTE: part of the V1Store interface.
520
func (s *SQLStore) SetSourceNode(ctx context.Context,
521
        node *models.Node) error {
×
522

×
523
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
524
                id, err := upsertNode(ctx, db, node)
×
525
                if err != nil {
×
526
                        return fmt.Errorf("unable to upsert source node: %w",
×
527
                                err)
×
528
                }
×
529

530
                // Make sure that if a source node for this version is already
531
                // set, then the ID is the same as the one we are about to set.
532
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
533
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
534
                        return fmt.Errorf("unable to fetch source node: %w",
×
535
                                err)
×
536
                } else if err == nil {
×
537
                        if dbSourceNodeID != id {
×
538
                                return fmt.Errorf("v1 source node already "+
×
539
                                        "set to a different node: %d vs %d",
×
540
                                        dbSourceNodeID, id)
×
541
                        }
×
542

543
                        return nil
×
544
                }
545

546
                return db.AddSourceNode(ctx, id)
×
547
        }, sqldb.NoOpReset)
548
}
549

550
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
551
// update timestamp within the passed range. This method can be used by two
552
// nodes to quickly determine if they have the same set of up to date node
553
// announcements.
554
//
555
// NOTE: This is part of the V1Store interface.
556
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
NEW
557
        opts ...IteratorOption) iter.Seq2[models.Node, error] {
×
558

×
NEW
559
        cfg := defaultIteratorConfig()
×
NEW
560
        for _, opt := range opts {
×
NEW
561
                opt(cfg)
×
NEW
562
        }
×
563

NEW
564
        return func(yield func(models.Node, error) bool) {
×
NEW
565
                var (
×
NEW
566
                        ctx            = context.TODO()
×
NEW
567
                        lastUpdateTime sql.NullInt64
×
NEW
568
                        lastPubKey     = make([]byte, 33)
×
NEW
569
                        hasMore        = true
×
570
                )
×
571

×
NEW
572
                // Each iteration, we'll read a batch of nodes, yield
×
NEW
573
                // them, then decide if we have more or not.
×
NEW
574
                for hasMore {
×
NEW
575
                        var batch []models.Node
×
NEW
576

×
NEW
577
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
578
                                rows, err := db.GetNodesByLastUpdateRange(
×
NEW
579
                                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
NEW
580
                                                StartTime:  sqldb.SQLInt64(startTime.Unix()),
×
NEW
581
                                                EndTime:    sqldb.SQLInt64(endTime.Unix()),
×
NEW
582
                                                LastUpdate: lastUpdateTime,
×
NEW
583
                                                LastPubKey: lastPubKey,
×
NEW
584
                                                OnlyPublic: sql.NullBool{
×
NEW
585
                                                        Bool: cfg.iterPublicNodes, Valid: true,
×
NEW
586
                                                },
×
NEW
587
                                                MaxResults: sqldb.SQLInt32(
×
NEW
588
                                                        cfg.nodeUpdateIterBatchSize,
×
NEW
589
                                                ),
×
NEW
590
                                        },
×
NEW
591
                                )
×
NEW
592
                                if err != nil {
×
NEW
593
                                        return err
×
NEW
594
                                }
×
595

NEW
596
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
NEW
597

×
NEW
598
                                err = forEachNodeInBatch(
×
NEW
599
                                        ctx, s.cfg.QueryCfg, db, rows,
×
NEW
600
                                        func(_ int64, node *models.Node) error {
×
NEW
601
                                                batch = append(batch, *node)
×
NEW
602

×
NEW
603
                                                // Update pagination cursors
×
NEW
604
                                                // based on the last processed
×
NEW
605
                                                // node.
×
NEW
606
                                                lastUpdateTime = sql.NullInt64{
×
NEW
607
                                                        Int64: node.LastUpdate.Unix(),
×
NEW
608
                                                        Valid: true,
×
NEW
609
                                                }
×
NEW
610
                                                lastPubKey = node.PubKeyBytes[:]
×
NEW
611

×
NEW
612
                                                return nil
×
NEW
613
                                        },
×
614
                                )
NEW
615
                                if err != nil {
×
NEW
616
                                        return fmt.Errorf("unable to build "+
×
NEW
617
                                                "nodes: %w", err)
×
NEW
618
                                }
×
619

620
                                return nil
×
621
                        }, sqldb.NoOpReset)
622

NEW
623
                        if err != nil {
×
NEW
624
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
NEW
625
                                        "error: %v", err)
×
NEW
626
                                yield(models.Node{}, err)
×
NEW
627
                                return
×
NEW
628
                        }
×
629

NEW
630
                        for _, node := range batch {
×
NEW
631
                                if !yield(node, nil) {
×
NEW
632
                                        return
×
NEW
633
                                }
×
634
                        }
635

636
                        // If the batch didn't yield anything, then we're done.
NEW
637
                        if len(batch) == 0 {
×
NEW
638
                                break
×
639
                        }
640
                }
641
        }
642
}
643
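
// exampleNodeUpdates is an illustrative sketch only (not part of the original
// file): it shows how a caller consumes the iter.Seq2 returned above using
// Go 1.23 range-over-func. The store value and the time window are
// assumptions, and no IteratorOption is supplied.
func exampleNodeUpdates(store *SQLStore) error {
        var (
                start = time.Now().Add(-24 * time.Hour)
                end   = time.Now()
        )

        for node, err := range store.NodeUpdatesInHorizon(start, end) {
                if err != nil {
                        // A single error is yielded and iteration then stops.
                        return err
                }

                log.Debugf("node %x updated at %v", node.PubKeyBytes,
                        node.LastUpdate)
        }

        return nil
}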

644
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
645
// undirected edge between the two target nodes is created. The information stored
646
// denotes the static attributes of the channel, such as the channelID, the keys
647
// involved in creation of the channel, and the set of features that the channel
648
// supports. The chanPoint and chanID are used to uniquely identify the edge
649
// globally within the database.
650
//
651
// NOTE: part of the V1Store interface.
652
func (s *SQLStore) AddChannelEdge(ctx context.Context,
653
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
654

×
655
        var alreadyExists bool
×
656
        r := &batch.Request[SQLQueries]{
×
657
                Opts: batch.NewSchedulerOptions(opts...),
×
658
                Reset: func() {
×
659
                        alreadyExists = false
×
660
                },
×
661
                Do: func(tx SQLQueries) error {
×
662
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
663

×
664
                        // Make sure that the channel doesn't already exist. We
×
665
                        // do this explicitly instead of relying on catching a
×
666
                        // unique constraint error because relying on SQL to
×
667
                        // throw that error would abort the entire batch of
×
668
                        // transactions.
×
669
                        _, err := tx.GetChannelBySCID(
×
670
                                ctx, sqlc.GetChannelBySCIDParams{
×
671
                                        Scid:    chanIDB,
×
672
                                        Version: int16(ProtocolV1),
×
673
                                },
×
674
                        )
×
675
                        if err == nil {
×
676
                                alreadyExists = true
×
677
                                return nil
×
678
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
679
                                return fmt.Errorf("unable to fetch channel: %w",
×
680
                                        err)
×
681
                        }
×
682

683
                        return insertChannel(ctx, tx, edge)
×
684
                },
685
                OnCommit: func(err error) error {
×
686
                        switch {
×
687
                        case err != nil:
×
688
                                return err
×
689
                        case alreadyExists:
×
690
                                return ErrEdgeAlreadyExist
×
691
                        default:
×
692
                                s.rejectCache.remove(edge.ChannelID)
×
693
                                s.chanCache.remove(edge.ChannelID)
×
694
                                return nil
×
695
                        }
696
                },
697
        }
698

699
        return s.chanScheduler.Execute(ctx, r)
×
700
}
701
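
// exampleAddChannel is an illustrative sketch only (not part of the original
// file): it shows the ErrEdgeAlreadyExist contract surfaced by the OnCommit
// hook above. The store, ctx and edge values are assumed to be supplied by
// the caller.
func exampleAddChannel(ctx context.Context, store *SQLStore,
        edge *models.ChannelEdgeInfo) error {

        err := store.AddChannelEdge(ctx, edge)
        switch {
        case errors.Is(err, ErrEdgeAlreadyExist):
                // The channel is already known; nothing further to do.
                return nil
        case err != nil:
                return err
        }

        return nil
}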

702
// HighestChanID returns the "highest" known channel ID in the channel graph.
703
// This represents the "newest" channel from the PoV of the chain. This method
704
// can be used by peers to quickly determine if their graphs are in sync.
705
//
706
// NOTE: This is part of the V1Store interface.
707
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
708
        var highestChanID uint64
×
709
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
710
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
711
                if errors.Is(err, sql.ErrNoRows) {
×
712
                        return nil
×
713
                } else if err != nil {
×
714
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
715
                                err)
×
716
                }
×
717

718
                highestChanID = byteOrder.Uint64(chanID)
×
719

×
720
                return nil
×
721
        }, sqldb.NoOpReset)
722
        if err != nil {
×
723
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
724
        }
×
725

726
        return highestChanID, nil
×
727
}
728

729
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
730
// within the database for the referenced channel. The `flags` attribute within
731
// the ChannelEdgePolicy determines which of the directed edges are being
732
// updated. If the flag is 1, then the first node's information is being
733
// updated, otherwise it's the second node's information. The node ordering is
734
// determined by the lexicographical ordering of the identity public keys of the
735
// nodes on either side of the channel.
736
//
737
// NOTE: part of the V1Store interface.
738
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
739
        edge *models.ChannelEdgePolicy,
740
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
741

×
742
        var (
×
743
                isUpdate1    bool
×
744
                edgeNotFound bool
×
745
                from, to     route.Vertex
×
746
        )
×
747

×
748
        r := &batch.Request[SQLQueries]{
×
749
                Opts: batch.NewSchedulerOptions(opts...),
×
750
                Reset: func() {
×
751
                        isUpdate1 = false
×
752
                        edgeNotFound = false
×
753
                },
×
754
                Do: func(tx SQLQueries) error {
×
755
                        var err error
×
756
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
757
                                ctx, tx, edge,
×
758
                        )
×
759
                        if err != nil {
×
760
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
761
                        }
×
762

763
                        // Silence ErrEdgeNotFound so that the batch can
764
                        // succeed, but propagate the error via local state.
765
                        if errors.Is(err, ErrEdgeNotFound) {
×
766
                                edgeNotFound = true
×
767
                                return nil
×
768
                        }
×
769

770
                        return err
×
771
                },
772
                OnCommit: func(err error) error {
×
773
                        switch {
×
774
                        case err != nil:
×
775
                                return err
×
776
                        case edgeNotFound:
×
777
                                return ErrEdgeNotFound
×
778
                        default:
×
779
                                s.updateEdgeCache(edge, isUpdate1)
×
780
                                return nil
×
781
                        }
782
                },
783
        }
784

785
        err := s.chanScheduler.Execute(ctx, r)
×
786

×
787
        return from, to, err
×
788
}
789
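
// exampleUpdatePolicy is an illustrative sketch only (not part of the
// original file): it shows how the returned endpoints and the ErrEdgeNotFound
// sentinel (propagated via the OnCommit hook above) might be handled. The
// store, ctx and policy values are assumed.
func exampleUpdatePolicy(ctx context.Context, store *SQLStore,
        policy *models.ChannelEdgePolicy) error {

        from, to, err := store.UpdateEdgePolicy(ctx, policy)
        switch {
        case errors.Is(err, ErrEdgeNotFound):
                // The referenced channel isn't in the graph yet; the policy
                // may be retried once the channel announcement arrives.
                return nil
        case err != nil:
                return err
        }

        log.Debugf("updated policy for channel between %x and %x", from, to)

        return nil
}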

790
// updateEdgeCache updates our reject and channel caches with the new
791
// edge policy information.
792
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
793
        isUpdate1 bool) {
×
794

×
795
        // If an entry for this channel is found in reject cache, we'll modify
×
796
        // the entry with the updated timestamp for the direction that was just
×
797
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
798
        // during the next query for this edge.
×
799
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
800
                if isUpdate1 {
×
801
                        entry.upd1Time = e.LastUpdate.Unix()
×
802
                } else {
×
803
                        entry.upd2Time = e.LastUpdate.Unix()
×
804
                }
×
805
                s.rejectCache.insert(e.ChannelID, entry)
×
806
        }
807

808
        // If an entry for this channel is found in channel cache, we'll modify
809
        // the entry with the updated policy for the direction that was just
810
        // written. If the edge doesn't exist, we'll defer loading the info and
811
        // policies and lazily read from disk during the next query.
812
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
813
                if isUpdate1 {
×
814
                        channel.Policy1 = e
×
815
                } else {
×
816
                        channel.Policy2 = e
×
817
                }
×
818
                s.chanCache.insert(e.ChannelID, channel)
×
819
        }
820
}
821

822
// ForEachSourceNodeChannel iterates through all channels of the source node,
823
// executing the passed callback on each. The call-back is provided with the
824
// channel's outpoint, whether we have a policy for the channel and the channel
825
// peer's node information.
826
//
827
// NOTE: part of the V1Store interface.
828
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
829
        cb func(chanPoint wire.OutPoint, havePolicy bool,
830
                otherNode *models.Node) error, reset func()) error {
×
831

×
832
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
833
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
834
                if err != nil {
×
835
                        return fmt.Errorf("unable to fetch source node: %w",
×
836
                                err)
×
837
                }
×
838

839
                return forEachNodeChannel(
×
840
                        ctx, db, s.cfg, nodeID,
×
841
                        func(info *models.ChannelEdgeInfo,
×
842
                                outPolicy *models.ChannelEdgePolicy,
×
843
                                _ *models.ChannelEdgePolicy) error {
×
844

×
845
                                // Fetch the other node.
×
846
                                var (
×
847
                                        otherNodePub [33]byte
×
848
                                        node1        = info.NodeKey1Bytes
×
849
                                        node2        = info.NodeKey2Bytes
×
850
                                )
×
851
                                switch {
×
852
                                case bytes.Equal(node1[:], nodePub[:]):
×
853
                                        otherNodePub = node2
×
854
                                case bytes.Equal(node2[:], nodePub[:]):
×
855
                                        otherNodePub = node1
×
856
                                default:
×
857
                                        return fmt.Errorf("node not " +
×
858
                                                "participating in this channel")
×
859
                                }
860

861
                                _, otherNode, err := getNodeByPubKey(
×
862
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
863
                                )
×
864
                                if err != nil {
×
865
                                        return fmt.Errorf("unable to fetch "+
×
866
                                                "other node(%x): %w",
×
867
                                                otherNodePub, err)
×
868
                                }
×
869

870
                                return cb(
×
871
                                        info.ChannelPoint, outPolicy != nil,
×
872
                                        otherNode,
×
873
                                )
×
874
                        },
875
                )
876
        }, reset)
877
}
878

879
// ForEachNode iterates through all the stored vertices/nodes in the graph,
880
// executing the passed callback with each node encountered. If the callback
881
// returns an error, then the transaction is aborted and the iteration stops
882
// early.
883
//
884
// NOTE: part of the V1Store interface.
885
func (s *SQLStore) ForEachNode(ctx context.Context,
886
        cb func(node *models.Node) error, reset func()) error {
×
887

×
888
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
889
                return forEachNodePaginated(
×
890
                        ctx, s.cfg.QueryCfg, db,
×
891
                        ProtocolV1, func(_ context.Context, _ int64,
×
892
                                node *models.Node) error {
×
893

×
894
                                return cb(node)
×
895
                        },
×
896
                )
897
        }, reset)
898
}
899
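
// exampleCountNodes is an illustrative sketch only (not part of the original
// file): it shows the callback/reset pattern used by ForEachNode, where the
// reset closure clears any state accumulated by the callback if the
// underlying read transaction is retried.
func exampleCountNodes(ctx context.Context, store *SQLStore) (int, error) {
        var count int
        err := store.ForEachNode(ctx, func(_ *models.Node) error {
                count++
                return nil
        }, func() {
                count = 0
        })

        return count, err
}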

900
// ForEachNodeDirectedChannel iterates through all channels of a given node,
901
// executing the passed callback on the directed edge representing the channel
902
// and its incoming policy. If the callback returns an error, then the iteration
903
// is halted with the error propagated back up to the caller.
904
//
905
// Unknown policies are passed into the callback as nil values.
906
//
907
// NOTE: this is part of the graphdb.NodeTraverser interface.
908
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
909
        cb func(channel *DirectedChannel) error, reset func()) error {
×
910

×
911
        var ctx = context.TODO()
×
912

×
913
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
914
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
915
        }, reset)
×
916
}
917

918
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
919
// graph, executing the passed callback with each node encountered. If the
920
// callback returns an error, then the transaction is aborted and the iteration
921
// stops early.
922
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
923
        cb func(route.Vertex, *lnwire.FeatureVector) error,
924
        reset func()) error {
×
925

×
926
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
927
                return forEachNodeCacheable(
×
928
                        ctx, s.cfg.QueryCfg, db,
×
929
                        func(_ int64, nodePub route.Vertex,
×
930
                                features *lnwire.FeatureVector) error {
×
931

×
932
                                return cb(nodePub, features)
×
933
                        },
×
934
                )
935
        }, reset)
936
        if err != nil {
×
937
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
938
        }
×
939

940
        return nil
×
941
}
942

943
// ForEachNodeChannel iterates through all channels of the given node,
944
// executing the passed callback with an edge info structure and the policies
945
// of each end of the channel. The first edge policy is the outgoing edge *to*
946
// the connecting node, while the second is the incoming edge *from* the
947
// connecting node. If the callback returns an error, then the iteration is
948
// halted with the error propagated back up to the caller.
949
//
950
// Unknown policies are passed into the callback as nil values.
951
//
952
// NOTE: part of the V1Store interface.
953
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
954
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
955
                *models.ChannelEdgePolicy) error, reset func()) error {
×
956

×
957
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
958
                dbNode, err := db.GetNodeByPubKey(
×
959
                        ctx, sqlc.GetNodeByPubKeyParams{
×
960
                                Version: int16(ProtocolV1),
×
961
                                PubKey:  nodePub[:],
×
962
                        },
×
963
                )
×
964
                if errors.Is(err, sql.ErrNoRows) {
×
965
                        return nil
×
966
                } else if err != nil {
×
967
                        return fmt.Errorf("unable to fetch node: %w", err)
×
968
                }
×
969

970
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
971
        }, reset)
972
}
973

974
// extractMaxUpdateTime returns the maximum of the two policy update times.
975
// This is used for pagination cursor tracking.
NEW
976
func extractMaxUpdateTime(row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
NEW
977
        if row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid {
×
NEW
978
                return max(row.Policy1LastUpdate.Int64, row.Policy2LastUpdate.Int64)
×
NEW
979
        } else if row.Policy1LastUpdate.Valid {
×
NEW
980
                return row.Policy1LastUpdate.Int64
×
NEW
981
        } else if row.Policy2LastUpdate.Valid {
×
NEW
982
                return row.Policy2LastUpdate.Int64
×
NEW
983
        }
×
NEW
984
        return 0
×
985
}
986

987
// buildChannelFromRow constructs a ChannelEdge from a database row.
988
// This includes building the nodes, channel info, and policies.
989
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
NEW
990
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
NEW
991

×
NEW
992
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
NEW
993
        if err != nil {
×
NEW
994
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w", err)
×
NEW
995
        }
×
996

NEW
997
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
NEW
998
        if err != nil {
×
NEW
999
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w", err)
×
NEW
1000
        }
×
1001

NEW
1002
        channel, err := getAndBuildEdgeInfo(
×
NEW
1003
                ctx, s.cfg, db,
×
NEW
1004
                row.GraphChannel, node1.PubKeyBytes,
×
NEW
1005
                node2.PubKeyBytes,
×
NEW
1006
        )
×
NEW
1007
        if err != nil {
×
NEW
1008
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
NEW
1009
                        "channel info: %w", err)
×
NEW
1010
        }
×
1011

NEW
1012
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1013
        if err != nil {
×
NEW
1014
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
NEW
1015
                        "channel policies: %w", err)
×
NEW
1016
        }
×
1017

NEW
1018
        p1, p2, err := getAndBuildChanPolicies(
×
NEW
1019
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
NEW
1020
                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
1021
        )
×
NEW
1022
        if err != nil {
×
NEW
1023
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
NEW
1024
                        "channel policies: %w", err)
×
NEW
1025
        }
×
1026

NEW
1027
        return ChannelEdge{
×
NEW
1028
                Info:    channel,
×
NEW
1029
                Policy1: p1,
×
NEW
1030
                Policy2: p2,
×
NEW
1031
                Node1:   node1,
×
NEW
1032
                Node2:   node2,
×
NEW
1033
        }, nil
×
1034
}
1035

1036
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1037
// This method acquires the cache lock only once for the entire batch.
NEW
1038
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
NEW
1039
        if len(edgesToCache) == 0 {
×
NEW
1040
                return
×
NEW
1041
        }
×
1042

NEW
1043
        s.cacheMu.Lock()
×
NEW
1044
        defer s.cacheMu.Unlock()
×
NEW
1045

×
NEW
1046
        for chanID, edge := range edgesToCache {
×
NEW
1047
                s.chanCache.insert(chanID, edge)
×
NEW
1048
        }
×
1049
}
1050

1051
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1052
// one policy with an update timestamp within the specified horizon.
1053
//
1054
// Iterator Lifecycle:
1055
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1056
// 2. Query batch of channels with policies in time range
1057
// 3. For each channel: check if seen, check cache, or build from DB
1058
// 4. Yield channels to caller
1059
// 5. Update cache after successful batch
1060
// 6. Repeat with updated pagination cursor until no more results
1061
//
1062
// NOTE: This is part of the V1Store interface.
1063
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
NEW
1064
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1065

×
NEW
1066
        // Apply options.
×
NEW
1067
        cfg := defaultIteratorConfig()
×
NEW
1068
        for _, opt := range opts {
×
NEW
1069
                opt(cfg)
×
NEW
1070
        }
×
1071

NEW
1072
        return func(yield func(ChannelEdge, error) bool) {
×
1073

×
NEW
1074
                var (
×
NEW
1075
                        ctx            = context.TODO()
×
NEW
1076
                        edgesSeen      = make(map[uint64]struct{})
×
NEW
1077
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
NEW
1078
                        hits           int
×
NEW
1079
                        total          int
×
NEW
1080
                        lastUpdateTime sql.NullInt64
×
NEW
1081
                        lastID         sql.NullInt64
×
NEW
1082
                        hasMore        = true
×
1083
                )
×
1084

×
NEW
1085
                // Each iteration, we'll read a batch of channel updates
×
NEW
1086
                // (consulting the cache along the way), yield them, then loop
×
NEW
1087
                // back to decide if we have any more updates to read out.
×
NEW
1088
                for hasMore {
×
NEW
1089
                        var batch []ChannelEdge
×
NEW
1090

×
NEW
1091
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1092
                                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
NEW
1093
                                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
1094
                                                Version:        int16(ProtocolV1),
×
NEW
1095
                                                StartTime:      sqldb.SQLInt64(startTime.Unix()),
×
NEW
1096
                                                EndTime:        sqldb.SQLInt64(endTime.Unix()),
×
NEW
1097
                                                LastUpdateTime: lastUpdateTime,
×
NEW
1098
                                                LastID:         lastID,
×
NEW
1099
                                                MaxResults: sql.NullInt32{
×
NEW
1100
                                                        Int32: int32(cfg.chanUpdateIterBatchSize),
×
NEW
1101
                                                        Valid: true,
×
NEW
1102
                                                },
×
NEW
1103
                                        },
×
NEW
1104
                                )
×
NEW
1105
                                if err != nil {
×
NEW
1106
                                        return err
×
NEW
1107
                                }
×
1108

NEW
1109
                                hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
NEW
1110

×
NEW
1111
                                for _, row := range rows {
×
NEW
1112
                                        // Update pagination cursor.
×
NEW
1113
                                        lastUpdateTime = sql.NullInt64{
×
NEW
1114
                                                Int64: extractMaxUpdateTime(row),
×
NEW
1115
                                                Valid: true,
×
NEW
1116
                                        }
×
NEW
1117
                                        lastID = sql.NullInt64{
×
NEW
1118
                                                Int64: row.GraphChannel.ID,
×
NEW
1119
                                                Valid: true,
×
NEW
1120
                                        }
×
NEW
1121

×
NEW
1122
                                        // Skip if we've already processed this channel.
×
NEW
1123
                                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
NEW
1124
                                        if _, ok := edgesSeen[chanIDInt]; ok {
×
NEW
1125
                                                continue
×
1126
                                        }
1127

NEW
1128
                                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
NEW
1129
                                                hits++
×
NEW
1130
                                                total++
×
NEW
1131
                                                edgesSeen[chanIDInt] = struct{}{}
×
NEW
1132
                                                batch = append(batch, channel)
×
NEW
1133
                                                continue
×
1134
                                        }
1135

NEW
1136
                                        chanEdge, err := s.buildChannelFromRow(ctx, db, row)
×
NEW
1137
                                        if err != nil {
×
NEW
1138
                                                return err
×
NEW
1139
                                        }
×
1140

NEW
1141
                                        edgesSeen[chanIDInt] = struct{}{}
×
NEW
1142
                                        edgesToCache[chanIDInt] = chanEdge
×
NEW
1143

×
NEW
1144
                                        batch = append(batch, chanEdge)
×
NEW
1145

×
NEW
1146
                                        total++
×
1147
                                }
1148

NEW
1149
                                return nil
×
NEW
1150
                        }, func() {
×
NEW
1151
                                batch = nil
×
NEW
1152
                        })
×
1153

NEW
1154
                        if err != nil {
×
NEW
1155
                                log.Errorf("ChanUpdatesInHorizon batch error: %v", err)
×
NEW
1156
                                yield(ChannelEdge{}, err)
×
NEW
1157
                                return
×
1158
                        }
×
1159

NEW
1160
                        for _, edge := range batch {
×
NEW
1161
                                if !yield(edge, nil) {
×
NEW
1162
                                        return
×
NEW
1163
                                }
×
1164
                        }
1165

1166
                        // Update cache after successful batch yield, setting
1167
                        // the cache lock only once for the entire batch.
NEW
1168
                        s.updateChanCacheBatch(edgesToCache)
×
NEW
1169
                        edgesToCache = make(map[uint64]ChannelEdge)
×
UNCOV
1170

×
NEW
1171
                        // If the batch didn't yield anything, then we're done.
×
NEW
1172
                        if len(batch) == 0 {
×
NEW
1173
                                break
×
1174
                        }
1175
                }
1176

NEW
1177
                if total > 0 {
×
NEW
1178
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
NEW
1179
                                "%.2f (%d/%d)", float64(hits)*100/float64(total),
×
NEW
1180
                                hits, total)
×
NEW
1181
                } else {
×
NEW
1182
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
NEW
1183
                                "in horizon (%s, %s)", startTime, endTime)
×
UNCOV
1184
                }
×
1185
        }
1186
}
1187

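// Example (illustrative sketch, not part of the store API): the loop above
// fills one batch per read transaction and yields it before the channel
// cache is updated. Assuming the method defined earlier in this file returns
// an iter.Seq2[ChannelEdge, error] over the horizon bounded by startTime and
// endTime, a caller consumes it with a plain range loop:
//
//      for edge, err := range store.ChanUpdatesInHorizon(ctx, start, end) {
//              if err != nil {
//                      return err
//              }
//              // Inspect edge here, e.g. edge.Info.
//      }
//
// Breaking out of the loop early stops the underlying pagination, since the
// yield call above reports that the consumer is done.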
1188
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1189
// data to the call-back. If withAddrs is true, then the call-back will also be
1190
// provided with the addresses associated with the node. The address retrieval
1191
// results in an additional round-trip to the database, so it should only be used
1192
// if the addresses are actually needed.
1193
//
1194
// NOTE: part of the V1Store interface.
1195
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1196
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1197
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1198

×
1199
        type nodeCachedBatchData struct {
×
1200
                features      map[int64][]int
×
1201
                addrs         map[int64][]nodeAddress
×
1202
                chanBatchData *batchChannelData
×
1203
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1204
        }
×
1205

×
1206
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1207
                // pageQueryFunc is used to query the next page of nodes.
×
1208
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1209
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1210

×
1211
                        return db.ListNodeIDsAndPubKeys(
×
1212
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1213
                                        Version: int16(ProtocolV1),
×
1214
                                        ID:      lastID,
×
1215
                                        Limit:   limit,
×
1216
                                },
×
1217
                        )
×
1218
                }
×
1219

1220
                // batchDataFunc is then used to batch load the data required
1221
                // for each page of nodes.
1222
                batchDataFunc := func(ctx context.Context,
×
1223
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1224

×
1225
                        // Batch load node features.
×
1226
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1227
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1228
                        )
×
1229
                        if err != nil {
×
1230
                                return nil, fmt.Errorf("unable to batch load "+
×
1231
                                        "node features: %w", err)
×
1232
                        }
×
1233

1234
                        // Maybe fetch the node's addresses if requested.
1235
                        var nodeAddrs map[int64][]nodeAddress
×
1236
                        if withAddrs {
×
1237
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1238
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1239
                                )
×
1240
                                if err != nil {
×
1241
                                        return nil, fmt.Errorf("unable to "+
×
1242
                                                "batch load node "+
×
1243
                                                "addresses: %w", err)
×
1244
                                }
×
1245
                        }
1246

1247
                        // Batch load ALL unique channels for ALL nodes in this
1248
                        // page.
1249
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1250
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1251
                                        Version:  int16(ProtocolV1),
×
1252
                                        Node1Ids: nodeIDs,
×
1253
                                        Node2Ids: nodeIDs,
×
1254
                                },
×
1255
                        )
×
1256
                        if err != nil {
×
1257
                                return nil, fmt.Errorf("unable to batch "+
×
1258
                                        "fetch channels for nodes: %w", err)
×
1259
                        }
×
1260

1261
                        // Deduplicate channels and collect IDs.
1262
                        var (
×
1263
                                allChannelIDs []int64
×
1264
                                allPolicyIDs  []int64
×
1265
                        )
×
1266
                        uniqueChannels := make(
×
1267
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1268
                        )
×
1269

×
1270
                        for _, channel := range allChannels {
×
1271
                                channelID := channel.GraphChannel.ID
×
1272

×
1273
                                // Only process each unique channel once.
×
1274
                                _, exists := uniqueChannels[channelID]
×
1275
                                if exists {
×
1276
                                        continue
×
1277
                                }
1278

1279
                                uniqueChannels[channelID] = channel
×
1280
                                allChannelIDs = append(allChannelIDs, channelID)
×
1281

×
1282
                                if channel.Policy1ID.Valid {
×
1283
                                        allPolicyIDs = append(
×
1284
                                                allPolicyIDs,
×
1285
                                                channel.Policy1ID.Int64,
×
1286
                                        )
×
1287
                                }
×
1288
                                if channel.Policy2ID.Valid {
×
1289
                                        allPolicyIDs = append(
×
1290
                                                allPolicyIDs,
×
1291
                                                channel.Policy2ID.Int64,
×
1292
                                        )
×
1293
                                }
×
1294
                        }
1295

1296
                        // Batch load channel data for all unique channels.
1297
                        channelBatchData, err := batchLoadChannelData(
×
1298
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1299
                                allPolicyIDs,
×
1300
                        )
×
1301
                        if err != nil {
×
1302
                                return nil, fmt.Errorf("unable to batch "+
×
1303
                                        "load channel data: %w", err)
×
1304
                        }
×
1305

1306
                        // Create map of node ID to channels that involve this
1307
                        // node.
1308
                        nodeIDSet := make(map[int64]bool)
×
1309
                        for _, nodeID := range nodeIDs {
×
1310
                                nodeIDSet[nodeID] = true
×
1311
                        }
×
1312

1313
                        nodeChannelMap := make(
×
1314
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1315
                        )
×
1316
                        for _, channel := range uniqueChannels {
×
1317
                                // Add channel to both nodes if they're in our
×
1318
                                // current page.
×
1319
                                node1 := channel.GraphChannel.NodeID1
×
1320
                                if nodeIDSet[node1] {
×
1321
                                        nodeChannelMap[node1] = append(
×
1322
                                                nodeChannelMap[node1], channel,
×
1323
                                        )
×
1324
                                }
×
1325
                                node2 := channel.GraphChannel.NodeID2
×
1326
                                if nodeIDSet[node2] {
×
1327
                                        nodeChannelMap[node2] = append(
×
1328
                                                nodeChannelMap[node2], channel,
×
1329
                                        )
×
1330
                                }
×
1331
                        }
1332

1333
                        return &nodeCachedBatchData{
×
1334
                                features:      nodeFeatures,
×
1335
                                addrs:         nodeAddrs,
×
1336
                                chanBatchData: channelBatchData,
×
1337
                                chanMap:       nodeChannelMap,
×
1338
                        }, nil
×
1339
                }
1340

1341
                // processItem is used to process each node in the current page.
1342
                processItem := func(ctx context.Context,
×
1343
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1344
                        batchData *nodeCachedBatchData) error {
×
1345

×
1346
                        // Build feature vector for this node.
×
1347
                        fv := lnwire.EmptyFeatureVector()
×
1348
                        features, exists := batchData.features[nodeData.ID]
×
1349
                        if exists {
×
1350
                                for _, bit := range features {
×
1351
                                        fv.Set(lnwire.FeatureBit(bit))
×
1352
                                }
×
1353
                        }
1354

1355
                        var nodePub route.Vertex
×
1356
                        copy(nodePub[:], nodeData.PubKey)
×
1357

×
1358
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1359

×
1360
                        toNodeCallback := func() route.Vertex {
×
1361
                                return nodePub
×
1362
                        }
×
1363

1364
                        // Build cached channels map for this node.
1365
                        channels := make(map[uint64]*DirectedChannel)
×
1366
                        for _, channelRow := range nodeChannels {
×
1367
                                directedChan, err := buildDirectedChannel(
×
1368
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1369
                                        channelRow, batchData.chanBatchData, fv,
×
1370
                                        toNodeCallback,
×
1371
                                )
×
1372
                                if err != nil {
×
1373
                                        return err
×
1374
                                }
×
1375

1376
                                channels[directedChan.ChannelID] = directedChan
×
1377
                        }
1378

1379
                        addrs, err := buildNodeAddresses(
×
1380
                                batchData.addrs[nodeData.ID],
×
1381
                        )
×
1382
                        if err != nil {
×
1383
                                return fmt.Errorf("unable to build node "+
×
1384
                                        "addresses: %w", err)
×
1385
                        }
×
1386

1387
                        return cb(ctx, nodePub, addrs, channels)
×
1388
                }
1389

1390
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1391
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1392
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1393
                                return node.ID
×
1394
                        },
×
1395
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1396
                                error) {
×
1397

×
1398
                                return node.ID, nil
×
1399
                        },
×
1400
                        batchDataFunc, processItem,
1401
                )
1402
        }, reset)
1403
}
1404

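// Example (illustrative sketch, not part of the store API): tallying the
// number of directed channels per node with ForEachNodeCached. Addresses are
// not needed here, so withAddrs is false and the extra address query is
// skipped. The reset closure clears the accumulator so that a retried
// transaction does not double count.
func exampleCountNodeChannels(ctx context.Context,
        s *SQLStore) (map[route.Vertex]int, error) {

        counts := make(map[route.Vertex]int)

        err := s.ForEachNodeCached(ctx, false,
                func(_ context.Context, node route.Vertex, _ []net.Addr,
                        chans map[uint64]*DirectedChannel) error {

                        counts[node] = len(chans)

                        return nil
                },
                func() {
                        counts = make(map[route.Vertex]int)
                },
        )
        if err != nil {
                return nil, err
        }

        return counts, nil
}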
1405
// ForEachChannelCacheable iterates through all the channel edges stored
1406
// within the graph and invokes the passed callback for each edge. The
1407
// callback takes two edges since, as this is a directed graph, both the
1408
// in/out edges are visited. If the callback returns an error, then the
1409
// transaction is aborted and the iteration stops early.
1410
//
1411
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1412
// pointer for that particular channel edge routing policy will be
1413
// passed into the callback.
1414
//
1415
// NOTE: this method is like ForEachChannel but fetches only the data
1416
// required for the graph cache.
1417
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1418
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1419
        reset func()) error {
×
1420

×
1421
        ctx := context.TODO()
×
1422

×
1423
        handleChannel := func(_ context.Context,
×
1424
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1425

×
1426
                node1, node2, err := buildNodeVertices(
×
1427
                        row.Node1Pubkey, row.Node2Pubkey,
×
1428
                )
×
1429
                if err != nil {
×
1430
                        return err
×
1431
                }
×
1432

1433
                edge := buildCacheableChannelInfo(
×
1434
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1435
                )
×
1436

×
1437
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1438
                if err != nil {
×
1439
                        return err
×
1440
                }
×
1441

1442
                pol1, pol2, err := buildCachedChanPolicies(
×
1443
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1444
                )
×
1445
                if err != nil {
×
1446
                        return err
×
1447
                }
×
1448

1449
                return cb(edge, pol1, pol2)
×
1450
        }
1451

1452
        extractCursor := func(
×
1453
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1454

×
1455
                return row.ID
×
1456
        }
×
1457

1458
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1459
                //nolint:ll
×
1460
                queryFunc := func(ctx context.Context, lastID int64,
×
1461
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1462
                        error) {
×
1463

×
1464
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1465
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1466
                                        Version: int16(ProtocolV1),
×
1467
                                        ID:      lastID,
×
1468
                                        Limit:   limit,
×
1469
                                },
×
1470
                        )
×
1471
                }
×
1472

1473
                return sqldb.ExecutePaginatedQuery(
×
1474
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1475
                        extractCursor, handleChannel,
×
1476
                )
×
1477
        }, reset)
1478
}
1479

1480
// ForEachChannel iterates through all the channel edges stored within the
1481
// graph and invokes the passed callback for each edge. The callback takes two
1482
// edges since, as this is a directed graph, both the in/out edges are visited.
1483
// If the callback returns an error, then the transaction is aborted and the
1484
// iteration stops early.
1485
//
1486
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1487
// for that particular channel edge routing policy will be passed into the
1488
// callback.
1489
//
1490
// NOTE: part of the V1Store interface.
1491
func (s *SQLStore) ForEachChannel(ctx context.Context,
1492
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1493
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1494

×
1495
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1496
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1497
        }, reset)
×
1498
}
1499

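// Example (illustrative sketch, not part of the store API): counting the
// channels that are missing a policy in at least one direction. As noted in
// the comment above, a nil policy is passed for a direction whose channel
// update has not been received yet.
func exampleCountUnpoliciedChannels(ctx context.Context,
        s *SQLStore) (int, error) {

        var missing int

        err := s.ForEachChannel(ctx,
                func(_ *models.ChannelEdgeInfo, pol1 *models.ChannelEdgePolicy,
                        pol2 *models.ChannelEdgePolicy) error {

                        if pol1 == nil || pol2 == nil {
                                missing++
                        }

                        return nil
                },
                func() {
                        missing = 0
                },
        )

        return missing, err
}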
1500
// FilterChannelRange returns the channel IDs of all known channels which were
1501
// mined in a block height within the passed range. The channel IDs are grouped
1502
// by their common block height. This method can be used to quickly share with a
1503
// peer the set of channels we know of within a particular range to catch them
1504
// up after a period of time offline. If withTimestamps is true then the
1505
// timestamp info of the latest received channel update messages of the channel
1506
// will be included in the response.
1507
//
1508
// NOTE: This is part of the V1Store interface.
1509
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1510
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1511

×
1512
        var (
×
1513
                ctx       = context.TODO()
×
1514
                startSCID = &lnwire.ShortChannelID{
×
1515
                        BlockHeight: startHeight,
×
1516
                }
×
1517
                endSCID = lnwire.ShortChannelID{
×
1518
                        BlockHeight: endHeight,
×
1519
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1520
                        TxPosition:  math.MaxUint16,
×
1521
                }
×
1522
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1523
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1524
        )
×
1525

×
1526
        // 1) get all channels where channelID is between start and end chan ID.
×
1527
        // 2) skip if not public (ie, no channel_proof)
×
1528
        // 3) collect that channel.
×
1529
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1530
        //    and add those timestamps to the collected channel.
×
1531
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1532
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1533
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1534
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1535
                                StartScid: chanIDStart,
×
1536
                                EndScid:   chanIDEnd,
×
1537
                        },
×
1538
                )
×
1539
                if err != nil {
×
1540
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1541
                                err)
×
1542
                }
×
1543

1544
                for _, dbChan := range dbChans {
×
1545
                        cid := lnwire.NewShortChanIDFromInt(
×
1546
                                byteOrder.Uint64(dbChan.Scid),
×
1547
                        )
×
1548
                        chanInfo := NewChannelUpdateInfo(
×
1549
                                cid, time.Time{}, time.Time{},
×
1550
                        )
×
1551

×
1552
                        if !withTimestamps {
×
1553
                                channelsPerBlock[cid.BlockHeight] = append(
×
1554
                                        channelsPerBlock[cid.BlockHeight],
×
1555
                                        chanInfo,
×
1556
                                )
×
1557

×
1558
                                continue
×
1559
                        }
1560

1561
                        //nolint:ll
1562
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1563
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1564
                                        Version:   int16(ProtocolV1),
×
1565
                                        ChannelID: dbChan.ID,
×
1566
                                        NodeID:    dbChan.NodeID1,
×
1567
                                },
×
1568
                        )
×
1569
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1570
                                return fmt.Errorf("unable to fetch node1 "+
×
1571
                                        "policy: %w", err)
×
1572
                        } else if err == nil {
×
1573
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1574
                                        node1Policy.LastUpdate.Int64, 0,
×
1575
                                )
×
1576
                        }
×
1577

1578
                        //nolint:ll
1579
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1580
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1581
                                        Version:   int16(ProtocolV1),
×
1582
                                        ChannelID: dbChan.ID,
×
1583
                                        NodeID:    dbChan.NodeID2,
×
1584
                                },
×
1585
                        )
×
1586
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1587
                                return fmt.Errorf("unable to fetch node2 "+
×
1588
                                        "policy: %w", err)
×
1589
                        } else if err == nil {
×
1590
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1591
                                        node2Policy.LastUpdate.Int64, 0,
×
1592
                                )
×
1593
                        }
×
1594

1595
                        channelsPerBlock[cid.BlockHeight] = append(
×
1596
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1597
                        )
×
1598
                }
1599

1600
                return nil
×
1601
        }, func() {
×
1602
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1603
        })
×
1604
        if err != nil {
×
1605
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1606
        }
×
1607

1608
        if len(channelsPerBlock) == 0 {
×
1609
                return nil, nil
×
1610
        }
×
1611

1612
        // Return the channel ranges in ascending block height order.
1613
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1614
        slices.Sort(blocks)
×
1615

×
1616
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1617
                return BlockChannelRange{
×
1618
                        Height:   block,
×
1619
                        Channels: channelsPerBlock[block],
×
1620
                }
×
1621
        }), nil
×
1622
}
1623

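// Example (illustrative sketch, not part of the store API): flattening the
// per-block ranges returned by FilterChannelRange into a single list of
// SCIDs, which is the shape typically needed when replying to a peer's
// channel range query.
func exampleSCIDsInRange(s *SQLStore, startHeight,
        endHeight uint32) ([]uint64, error) {

        ranges, err := s.FilterChannelRange(startHeight, endHeight, false)
        if err != nil {
                return nil, err
        }

        var scids []uint64
        for _, blockRange := range ranges {
                for _, info := range blockRange.Channels {
                        scids = append(scids, info.ShortChannelID.ToUint64())
                }
        }

        return scids, nil
}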
1624
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1625
// zombie. This method is used on an ad-hoc basis, when channels need to be
1626
// marked as zombies outside the normal pruning cycle.
1627
//
1628
// NOTE: part of the V1Store interface.
1629
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1630
        pubKey1, pubKey2 [33]byte) error {
×
1631

×
1632
        ctx := context.TODO()
×
1633

×
1634
        s.cacheMu.Lock()
×
1635
        defer s.cacheMu.Unlock()
×
1636

×
1637
        chanIDB := channelIDToBytes(chanID)
×
1638

×
1639
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1640
                return db.UpsertZombieChannel(
×
1641
                        ctx, sqlc.UpsertZombieChannelParams{
×
1642
                                Version:  int16(ProtocolV1),
×
1643
                                Scid:     chanIDB,
×
1644
                                NodeKey1: pubKey1[:],
×
1645
                                NodeKey2: pubKey2[:],
×
1646
                        },
×
1647
                )
×
1648
        }, sqldb.NoOpReset)
×
1649
        if err != nil {
×
1650
                return fmt.Errorf("unable to upsert zombie channel "+
×
1651
                        "(channel_id=%d): %w", chanID, err)
×
1652
        }
×
1653

1654
        s.rejectCache.remove(chanID)
×
1655
        s.chanCache.remove(chanID)
×
1656

×
1657
        return nil
×
1658
}
1659

1660
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1661
//
1662
// NOTE: part of the V1Store interface.
1663
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1664
        s.cacheMu.Lock()
×
1665
        defer s.cacheMu.Unlock()
×
1666

×
1667
        var (
×
1668
                ctx     = context.TODO()
×
1669
                chanIDB = channelIDToBytes(chanID)
×
1670
        )
×
1671

×
1672
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1673
                res, err := db.DeleteZombieChannel(
×
1674
                        ctx, sqlc.DeleteZombieChannelParams{
×
1675
                                Scid:    chanIDB,
×
1676
                                Version: int16(ProtocolV1),
×
1677
                        },
×
1678
                )
×
1679
                if err != nil {
×
1680
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1681
                                err)
×
1682
                }
×
1683

1684
                rows, err := res.RowsAffected()
×
1685
                if err != nil {
×
1686
                        return err
×
1687
                }
×
1688

1689
                if rows == 0 {
×
1690
                        return ErrZombieEdgeNotFound
×
1691
                } else if rows > 1 {
×
1692
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1693
                                "expected 1", rows)
×
1694
                }
×
1695

1696
                return nil
×
1697
        }, sqldb.NoOpReset)
1698
        if err != nil {
×
1699
                return fmt.Errorf("unable to mark edge live "+
×
1700
                        "(channel_id=%d): %w", chanID, err)
×
1701
        }
×
1702

1703
        s.rejectCache.remove(chanID)
×
1704
        s.chanCache.remove(chanID)
×
1705

×
1706
        return err
×
1707
}
1708

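// Example (illustrative sketch, not part of the store API): the typical
// zombie lifecycle. A channel is first marked as a zombie together with the
// two node keys that are allowed to resurrect it, and is cleared again via
// MarkEdgeLive once a fresh update is seen.
func exampleZombieLifecycle(s *SQLStore, chanID uint64,
        pubKey1, pubKey2 [33]byte) error {

        if err := s.MarkEdgeZombie(chanID, pubKey1, pubKey2); err != nil {
                return err
        }

        // ... later, after a fresh channel update has been validated ...

        return s.MarkEdgeLive(chanID)
}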
1709
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1710
// zombie, then the two node public keys corresponding to this edge are also
1711
// returned.
1712
//
1713
// NOTE: part of the V1Store interface.
1714
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1715
        error) {
×
1716

×
1717
        var (
×
1718
                ctx              = context.TODO()
×
1719
                isZombie         bool
×
1720
                pubKey1, pubKey2 route.Vertex
×
1721
                chanIDB          = channelIDToBytes(chanID)
×
1722
        )
×
1723

×
1724
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1725
                zombie, err := db.GetZombieChannel(
×
1726
                        ctx, sqlc.GetZombieChannelParams{
×
1727
                                Scid:    chanIDB,
×
1728
                                Version: int16(ProtocolV1),
×
1729
                        },
×
1730
                )
×
1731
                if errors.Is(err, sql.ErrNoRows) {
×
1732
                        return nil
×
1733
                }
×
1734
                if err != nil {
×
1735
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1736
                                err)
×
1737
                }
×
1738

1739
                copy(pubKey1[:], zombie.NodeKey1)
×
1740
                copy(pubKey2[:], zombie.NodeKey2)
×
1741
                isZombie = true
×
1742

×
1743
                return nil
×
1744
        }, sqldb.NoOpReset)
1745
        if err != nil {
×
1746
                return false, route.Vertex{}, route.Vertex{},
×
1747
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1748
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1749
        }
×
1750

1751
        return isZombie, pubKey1, pubKey2, nil
×
1752
}
1753

1754
// NumZombies returns the current number of zombie channels in the graph.
1755
//
1756
// NOTE: part of the V1Store interface.
1757
func (s *SQLStore) NumZombies() (uint64, error) {
×
1758
        var (
×
1759
                ctx        = context.TODO()
×
1760
                numZombies uint64
×
1761
        )
×
1762
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1763
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1764
                if err != nil {
×
1765
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1766
                                err)
×
1767
                }
×
1768

1769
                numZombies = uint64(count)
×
1770

×
1771
                return nil
×
1772
        }, sqldb.NoOpReset)
1773
        if err != nil {
×
1774
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1775
        }
×
1776

1777
        return numZombies, nil
×
1778
}
1779

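// Example (illustrative sketch, not part of the store API): inspecting the
// zombie index for a single SCID and logging the total number of zombie
// channels currently tracked.
func exampleInspectZombies(s *SQLStore, chanID uint64) error {
        isZombie, pubKey1, pubKey2, err := s.IsZombieEdge(chanID)
        if err != nil {
                return err
        }
        if isZombie {
                log.Debugf("channel %d is a zombie (keys %x, %x)", chanID,
                        pubKey1, pubKey2)
        }

        numZombies, err := s.NumZombies()
        if err != nil {
                return err
        }

        log.Debugf("graph currently tracks %d zombie channels", numZombies)

        return nil
}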
1780
// DeleteChannelEdges removes edges with the given channel IDs from the
1781
// database and marks them as zombies. This ensures that we're unable to re-add
1782
// it to our database once again. If an edge does not exist within the
1783
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1784
// true, then when we mark these edges as zombies, we'll set up the keys such
1785
// that we require the node that failed to send the fresh update to be the one
1786
// that resurrects the channel from its zombie state. The markZombie bool
1787
// denotes whether to mark the channel as a zombie.
1788
//
1789
// NOTE: part of the V1Store interface.
1790
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1791
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1792

×
1793
        s.cacheMu.Lock()
×
1794
        defer s.cacheMu.Unlock()
×
1795

×
1796
        // Keep track of which channels we end up finding so that we can
×
1797
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1798
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1799
        for _, chanID := range chanIDs {
×
1800
                chanLookup[chanID] = struct{}{}
×
1801
        }
×
1802

1803
        var (
×
1804
                ctx   = context.TODO()
×
1805
                edges []*models.ChannelEdgeInfo
×
1806
        )
×
1807
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1808
                // First, collect all channel rows.
×
1809
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1810
                chanCallBack := func(ctx context.Context,
×
1811
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1812

×
1813
                        // Deleting the entry from the map indicates that we
×
1814
                        // have found the channel.
×
1815
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1816
                        delete(chanLookup, scid)
×
1817

×
1818
                        channelRows = append(channelRows, row)
×
1819

×
1820
                        return nil
×
1821
                }
×
1822

1823
                err := s.forEachChanWithPoliciesInSCIDList(
×
1824
                        ctx, db, chanCallBack, chanIDs,
×
1825
                )
×
1826
                if err != nil {
×
1827
                        return err
×
1828
                }
×
1829

1830
                if len(chanLookup) > 0 {
×
1831
                        return ErrEdgeNotFound
×
1832
                }
×
1833

1834
                if len(channelRows) == 0 {
×
1835
                        return nil
×
1836
                }
×
1837

1838
                // Batch build all channel edges.
1839
                var chanIDsToDelete []int64
×
1840
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1841
                        ctx, s.cfg, db, channelRows,
×
1842
                )
×
1843
                if err != nil {
×
1844
                        return err
×
1845
                }
×
1846

1847
                if markZombie {
×
1848
                        for i, row := range channelRows {
×
1849
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1850

×
1851
                                err := handleZombieMarking(
×
1852
                                        ctx, db, row, edges[i],
×
1853
                                        strictZombiePruning, scid,
×
1854
                                )
×
1855
                                if err != nil {
×
1856
                                        return fmt.Errorf("unable to mark "+
×
1857
                                                "channel as zombie: %w", err)
×
1858
                                }
×
1859
                        }
1860
                }
1861

1862
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1863
        }, func() {
×
1864
                edges = nil
×
1865

×
1866
                // Re-fill the lookup map.
×
1867
                for _, chanID := range chanIDs {
×
1868
                        chanLookup[chanID] = struct{}{}
×
1869
                }
×
1870
        })
1871
        if err != nil {
×
1872
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1873
                        err)
×
1874
        }
×
1875

1876
        for _, chanID := range chanIDs {
×
1877
                s.rejectCache.remove(chanID)
×
1878
                s.chanCache.remove(chanID)
×
1879
        }
×
1880

1881
        return edges, nil
×
1882
}
1883

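// Example (illustrative sketch, not part of the store API): removing a batch
// of closed channels and marking them as zombies so that stale gossip cannot
// re-add them. With strict zombie pruning enabled, the node that failed to
// send a fresh update must be the one to resurrect the channel.
func exampleRemoveClosedChannels(s *SQLStore, scids []uint64) error {
        removed, err := s.DeleteChannelEdges(true, true, scids...)
        if errors.Is(err, ErrEdgeNotFound) {
                // At least one SCID was unknown, so the transaction was
                // rolled back and nothing was removed.
                return err
        }
        if err != nil {
                return err
        }

        log.Debugf("removed %d channel edges", len(removed))

        return nil
}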
1884
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1885
// channel identified by the channel ID. If the channel can't be found, then
1886
// ErrEdgeNotFound is returned. A struct which houses the general information
1887
// for the channel itself is returned as well as two structs that contain the
1888
// routing policies for the channel in either direction.
1889
//
1890
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1891
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1892
// the ChannelEdgeInfo will only include the public keys of each node.
1893
//
1894
// NOTE: part of the V1Store interface.
1895
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1896
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1897
        *models.ChannelEdgePolicy, error) {
×
1898

×
1899
        var (
×
1900
                ctx              = context.TODO()
×
1901
                edge             *models.ChannelEdgeInfo
×
1902
                policy1, policy2 *models.ChannelEdgePolicy
×
1903
                chanIDB          = channelIDToBytes(chanID)
×
1904
        )
×
1905
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1906
                row, err := db.GetChannelBySCIDWithPolicies(
×
1907
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1908
                                Scid:    chanIDB,
×
1909
                                Version: int16(ProtocolV1),
×
1910
                        },
×
1911
                )
×
1912
                if errors.Is(err, sql.ErrNoRows) {
×
1913
                        // First check if this edge is perhaps in the zombie
×
1914
                        // index.
×
1915
                        zombie, err := db.GetZombieChannel(
×
1916
                                ctx, sqlc.GetZombieChannelParams{
×
1917
                                        Scid:    chanIDB,
×
1918
                                        Version: int16(ProtocolV1),
×
1919
                                },
×
1920
                        )
×
1921
                        if errors.Is(err, sql.ErrNoRows) {
×
1922
                                return ErrEdgeNotFound
×
1923
                        } else if err != nil {
×
1924
                                return fmt.Errorf("unable to check if "+
×
1925
                                        "channel is zombie: %w", err)
×
1926
                        }
×
1927

1928
                        // At this point, we know the channel is a zombie, so
1929
                        // we'll return an error indicating this, and we will
1930
                        // populate the edge info with the public keys of each
1931
                        // party as this is the only information we have about
1932
                        // it.
1933
                        edge = &models.ChannelEdgeInfo{}
×
1934
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1935
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1936

×
1937
                        return ErrZombieEdge
×
1938
                } else if err != nil {
×
1939
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1940
                }
×
1941

1942
                node1, node2, err := buildNodeVertices(
×
1943
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1944
                )
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                edge, err = getAndBuildEdgeInfo(
×
1950
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1951
                )
×
1952
                if err != nil {
×
1953
                        return fmt.Errorf("unable to build channel info: %w",
×
1954
                                err)
×
1955
                }
×
1956

1957
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1958
                if err != nil {
×
1959
                        return fmt.Errorf("unable to extract channel "+
×
1960
                                "policies: %w", err)
×
1961
                }
×
1962

1963
                policy1, policy2, err = getAndBuildChanPolicies(
×
1964
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1965
                        node1, node2,
×
1966
                )
×
1967
                if err != nil {
×
1968
                        return fmt.Errorf("unable to build channel "+
×
1969
                                "policies: %w", err)
×
1970
                }
×
1971

1972
                return nil
×
1973
        }, sqldb.NoOpReset)
1974
        if err != nil {
×
1975
                // If we are returning the ErrZombieEdge, then we also need to
×
1976
                // return the edge info as the method comment indicates that
×
1977
                // this will be populated when the edge is a zombie.
×
1978
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1979
                        err)
×
1980
        }
×
1981

1982
        return edge, policy1, policy2, nil
×
1983
}
1984

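// Example (illustrative sketch, not part of the store API): fetching a
// channel by SCID while treating the zombie case separately. As described in
// the comment above, the edge info returned alongside ErrZombieEdge still
// carries the node public keys.
func exampleLookupBySCID(s *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                log.Debugf("channel %d is a zombie between %x and %x", chanID,
                        info.NodeKey1Bytes, info.NodeKey2Bytes)

                return nil

        case err != nil:
                return err
        }

        log.Debugf("channel %d known, have policy1=%v, policy2=%v", chanID,
                pol1 != nil, pol2 != nil)

        return nil
}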
1985
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1986
// the channel identified by the funding outpoint. If the channel can't be
1987
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1988
// information for the channel itself is returned as well as two structs that
1989
// contain the routing policies for the channel in either direction.
1990
//
1991
// NOTE: part of the V1Store interface.
1992
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1993
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1994
        *models.ChannelEdgePolicy, error) {
×
1995

×
1996
        var (
×
1997
                ctx              = context.TODO()
×
1998
                edge             *models.ChannelEdgeInfo
×
1999
                policy1, policy2 *models.ChannelEdgePolicy
×
2000
        )
×
2001
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2002
                row, err := db.GetChannelByOutpointWithPolicies(
×
2003
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2004
                                Outpoint: op.String(),
×
2005
                                Version:  int16(ProtocolV1),
×
2006
                        },
×
2007
                )
×
2008
                if errors.Is(err, sql.ErrNoRows) {
×
2009
                        return ErrEdgeNotFound
×
2010
                } else if err != nil {
×
2011
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2012
                }
×
2013

2014
                node1, node2, err := buildNodeVertices(
×
2015
                        row.Node1Pubkey, row.Node2Pubkey,
×
2016
                )
×
2017
                if err != nil {
×
2018
                        return err
×
2019
                }
×
2020

2021
                edge, err = getAndBuildEdgeInfo(
×
2022
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2023
                )
×
2024
                if err != nil {
×
2025
                        return fmt.Errorf("unable to build channel info: %w",
×
2026
                                err)
×
2027
                }
×
2028

2029
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2030
                if err != nil {
×
2031
                        return fmt.Errorf("unable to extract channel "+
×
2032
                                "policies: %w", err)
×
2033
                }
×
2034

2035
                policy1, policy2, err = getAndBuildChanPolicies(
×
2036
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2037
                        node1, node2,
×
2038
                )
×
2039
                if err != nil {
×
2040
                        return fmt.Errorf("unable to build channel "+
×
2041
                                "policies: %w", err)
×
2042
                }
×
2043

2044
                return nil
×
2045
        }, sqldb.NoOpReset)
2046
        if err != nil {
×
2047
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2048
                        err)
×
2049
        }
×
2050

2051
        return edge, policy1, policy2, nil
×
2052
}
2053

2054
// HasChannelEdge returns true if the database knows of a channel edge with the
2055
// passed channel ID, and false otherwise. If an edge with that ID is found
2056
// within the graph, then two time stamps representing the last time the edge
2057
// was updated for both directed edges are returned along with the boolean. If
2058
// it is not found, then the zombie index is checked and its result is returned
2059
// as the second boolean.
2060
//
2061
// NOTE: part of the V1Store interface.
2062
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2063
        bool, error) {
×
2064

×
2065
        ctx := context.TODO()
×
2066

×
2067
        var (
×
2068
                exists          bool
×
2069
                isZombie        bool
×
2070
                node1LastUpdate time.Time
×
2071
                node2LastUpdate time.Time
×
2072
        )
×
2073

×
2074
        // We'll query the cache with the shared lock held to allow multiple
×
2075
        // readers to access values in the cache concurrently if they exist.
×
2076
        s.cacheMu.RLock()
×
2077
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2078
                s.cacheMu.RUnlock()
×
2079
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2080
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2081
                exists, isZombie = entry.flags.unpack()
×
2082

×
2083
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2084
        }
×
2085
        s.cacheMu.RUnlock()
×
2086

×
2087
        s.cacheMu.Lock()
×
2088
        defer s.cacheMu.Unlock()
×
2089

×
2090
        // The item was not found with the shared lock, so we'll acquire the
×
2091
        // exclusive lock and check the cache again in case another method added
×
2092
        // the entry to the cache while no lock was held.
×
2093
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2094
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2095
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2096
                exists, isZombie = entry.flags.unpack()
×
2097

×
2098
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2099
        }
×
2100

2101
        chanIDB := channelIDToBytes(chanID)
×
2102
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2103
                channel, err := db.GetChannelBySCID(
×
2104
                        ctx, sqlc.GetChannelBySCIDParams{
×
2105
                                Scid:    chanIDB,
×
2106
                                Version: int16(ProtocolV1),
×
2107
                        },
×
2108
                )
×
2109
                if errors.Is(err, sql.ErrNoRows) {
×
2110
                        // Check if it is a zombie channel.
×
2111
                        isZombie, err = db.IsZombieChannel(
×
2112
                                ctx, sqlc.IsZombieChannelParams{
×
2113
                                        Scid:    chanIDB,
×
2114
                                        Version: int16(ProtocolV1),
×
2115
                                },
×
2116
                        )
×
2117
                        if err != nil {
×
2118
                                return fmt.Errorf("could not check if channel "+
×
2119
                                        "is zombie: %w", err)
×
2120
                        }
×
2121

2122
                        return nil
×
2123
                } else if err != nil {
×
2124
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2125
                }
×
2126

2127
                exists = true
×
2128

×
2129
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2130
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2131
                                Version:   int16(ProtocolV1),
×
2132
                                ChannelID: channel.ID,
×
2133
                                NodeID:    channel.NodeID1,
×
2134
                        },
×
2135
                )
×
2136
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2137
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2138
                                err)
×
2139
                } else if err == nil {
×
2140
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2141
                }
×
2142

2143
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2144
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2145
                                Version:   int16(ProtocolV1),
×
2146
                                ChannelID: channel.ID,
×
2147
                                NodeID:    channel.NodeID2,
×
2148
                        },
×
2149
                )
×
2150
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2151
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2152
                                err)
×
2153
                } else if err == nil {
×
2154
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2155
                }
×
2156

2157
                return nil
×
2158
        }, sqldb.NoOpReset)
2159
        if err != nil {
×
2160
                return time.Time{}, time.Time{}, false, false,
×
2161
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2162
        }
×
2163

2164
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2165
                upd1Time: node1LastUpdate.Unix(),
×
2166
                upd2Time: node2LastUpdate.Unix(),
×
2167
                flags:    packRejectFlags(exists, isZombie),
×
2168
        })
×
2169

×
2170
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2171
}
2172

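// Example (illustrative sketch, not part of the store API): using the
// reject-cache backed HasChannelEdge to find the most recent update time of
// a known, live channel.
func exampleLastChanUpdate(s *SQLStore, chanID uint64) (time.Time, error) {
        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return time.Time{}, err
        }

        if !exists || isZombie {
                return time.Time{}, ErrEdgeNotFound
        }

        // Return the more recent of the two per-direction update times.
        if upd2.After(upd1) {
                return upd2, nil
        }

        return upd1, nil
}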
2173
// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
2174
// passed channel point (outpoint). If the passed channel doesn't exist within
2175
// the database, then ErrEdgeNotFound is returned.
2176
//
2177
// NOTE: part of the V1Store interface.
2178
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2179
        var (
×
2180
                ctx       = context.TODO()
×
2181
                channelID uint64
×
2182
        )
×
2183
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2184
                chanID, err := db.GetSCIDByOutpoint(
×
2185
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2186
                                Outpoint: chanPoint.String(),
×
2187
                                Version:  int16(ProtocolV1),
×
2188
                        },
×
2189
                )
×
2190
                if errors.Is(err, sql.ErrNoRows) {
×
2191
                        return ErrEdgeNotFound
×
2192
                } else if err != nil {
×
2193
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2194
                                err)
×
2195
                }
×
2196

2197
                channelID = byteOrder.Uint64(chanID)
×
2198

×
2199
                return nil
×
2200
        }, sqldb.NoOpReset)
2201
        if err != nil {
×
2202
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2203
        }
×
2204

2205
        return channelID, nil
×
2206
}
2207

2208
// IsPublicNode is a helper method that determines whether the node with the
2209
// given public key is seen as a public node in the graph from the graph's
2210
// source node's point of view.
2211
//
2212
// NOTE: part of the V1Store interface.
2213
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2214
        ctx := context.TODO()
×
2215

×
2216
        var isPublic bool
×
2217
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2218
                var err error
×
2219
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2220

×
2221
                return err
×
2222
        }, sqldb.NoOpReset)
×
2223
        if err != nil {
×
2224
                return false, fmt.Errorf("unable to check if node is "+
×
2225
                        "public: %w", err)
×
2226
        }
×
2227

2228
        return isPublic, nil
×
2229
}
2230

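// Example (illustrative sketch, not part of the store API): mapping a
// funding outpoint to its SCID and checking whether one of the channel's
// nodes is announced publicly.
func exampleOutpointInfo(s *SQLStore, op *wire.OutPoint,
        nodeKey [33]byte) (uint64, bool, error) {

        scid, err := s.ChannelID(op)
        if err != nil {
                return 0, false, err
        }

        public, err := s.IsPublicNode(nodeKey)
        if err != nil {
                return 0, false, err
        }

        return scid, public, nil
}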
2231
// FetchChanInfos returns the set of channel edges that correspond to the passed
2232
// channel IDs. If an edge in the query is unknown to the database, it will be
2233
// skipped and the result will contain only those edges that exist at the time
2234
// of the query. This can be used to respond to peer queries that are seeking to
2235
// fill in gaps in their view of the channel graph.
2236
//
2237
// NOTE: part of the V1Store interface.
2238
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2239
        var (
×
2240
                ctx   = context.TODO()
×
2241
                edges = make(map[uint64]ChannelEdge)
×
2242
        )
×
2243
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2244
                // First, collect all channel rows.
×
2245
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2246
                chanCallBack := func(ctx context.Context,
×
2247
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2248

×
2249
                        channelRows = append(channelRows, row)
×
2250
                        return nil
×
2251
                }
×
2252

2253
                err := s.forEachChanWithPoliciesInSCIDList(
×
2254
                        ctx, db, chanCallBack, chanIDs,
×
2255
                )
×
2256
                if err != nil {
×
2257
                        return err
×
2258
                }
×
2259

2260
                if len(channelRows) == 0 {
×
2261
                        return nil
×
2262
                }
×
2263

2264
                // Batch build all channel edges.
2265
                chans, err := batchBuildChannelEdges(
×
2266
                        ctx, s.cfg, db, channelRows,
×
2267
                )
×
2268
                if err != nil {
×
2269
                        return fmt.Errorf("unable to build channel edges: %w",
×
2270
                                err)
×
2271
                }
×
2272

2273
                for _, c := range chans {
×
2274
                        edges[c.Info.ChannelID] = c
×
2275
                }
×
2276

2277
                return err
×
2278
        }, func() {
×
2279
                clear(edges)
×
2280
        })
×
2281
        if err != nil {
×
2282
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2283
        }
×
2284

2285
        res := make([]ChannelEdge, 0, len(edges))
×
2286
        for _, chanID := range chanIDs {
×
2287
                edge, ok := edges[chanID]
×
2288
                if !ok {
×
2289
                        continue
×
2290
                }
2291

2292
                res = append(res, edge)
×
2293
        }
2294

2295
        return res, nil
×
2296
}
2297

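// Example (illustrative sketch, not part of the store API): answering a
// peer's query for specific SCIDs. Unknown channels are silently omitted, so
// the reply may contain fewer entries than the request.
func exampleKnownChannels(s *SQLStore, requested []uint64) ([]uint64, error) {
        chans, err := s.FetchChanInfos(requested)
        if err != nil {
                return nil, err
        }

        known := make([]uint64, 0, len(chans))
        for _, c := range chans {
                known = append(known, c.Info.ChannelID)
        }

        return known, nil
}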
2298
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2299
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2300
// channels in a paginated manner.
2301
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2302
        db SQLQueries, cb func(ctx context.Context,
2303
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2304
        chanIDs []uint64) error {
×
2305

×
2306
        queryWrapper := func(ctx context.Context,
×
2307
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2308
                error) {
×
2309

×
2310
                return db.GetChannelsBySCIDWithPolicies(
×
2311
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2312
                                Version: int16(ProtocolV1),
×
2313
                                Scids:   scids,
×
2314
                        },
×
2315
                )
×
2316
        }
×
2317

2318
        return sqldb.ExecuteBatchQuery(
×
2319
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2320
                cb,
×
2321
        )
×
2322
}
2323

2324
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
2325
// IDs that we don't know and are not known zombies of the passed set. In other
2326
// words, we perform a set difference of our set of chan IDs and the ones
2327
// passed in. This method can be used by callers to determine the set of
2328
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2329
// known zombies are also returned.
2330
//
2331
// NOTE: part of the V1Store interface.
2332
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2333
        []ChannelUpdateInfo, error) {
×
2334

×
2335
        var (
×
2336
                ctx          = context.TODO()
×
2337
                newChanIDs   []uint64
×
2338
                knownZombies []ChannelUpdateInfo
×
2339
                infoLookup   = make(
×
2340
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2341
                )
×
2342
        )
×
2343

×
2344
        // We first build a lookup map of the channel IDs to the
×
2345
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2346
        // already know about.
×
2347
        for _, chanInfo := range chansInfo {
×
2348
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2349
        }
×
2350

2351
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2352
                // The call-back function deletes known channels from
×
2353
                // infoLookup, so that we can later check which channels are
×
2354
                // zombies by only looking at the remaining channels in the set.
×
2355
                cb := func(ctx context.Context,
×
2356
                        channel sqlc.GraphChannel) error {
×
2357

×
2358
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2359

×
2360
                        return nil
×
2361
                }
×
2362

2363
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2364
                if err != nil {
×
2365
                        return fmt.Errorf("unable to iterate through "+
×
2366
                                "channels: %w", err)
×
2367
                }
×
2368

2369
                // We want to ensure that we deal with the channels in the
2370
                // same order that they were passed in, so we iterate over the
2371
                // original chansInfo slice and then check if that channel is
2372
                // still in the infoLookup map.
2373
                for _, chanInfo := range chansInfo {
×
2374
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2375
                        if _, ok := infoLookup[channelID]; !ok {
×
2376
                                continue
×
2377
                        }
2378

2379
                        isZombie, err := db.IsZombieChannel(
×
2380
                                ctx, sqlc.IsZombieChannelParams{
×
2381
                                        Scid:    channelIDToBytes(channelID),
×
2382
                                        Version: int16(ProtocolV1),
×
2383
                                },
×
2384
                        )
×
2385
                        if err != nil {
×
2386
                                return fmt.Errorf("unable to fetch zombie "+
×
2387
                                        "channel: %w", err)
×
2388
                        }
×
2389

2390
                        if isZombie {
×
2391
                                knownZombies = append(knownZombies, chanInfo)
×
2392

×
2393
                                continue
×
2394
                        }
2395

2396
                        newChanIDs = append(newChanIDs, channelID)
×
2397
                }
2398

2399
                return nil
×
2400
        }, func() {
×
2401
                newChanIDs = nil
×
2402
                knownZombies = nil
×
2403
                // Rebuild the infoLookup map in case of a rollback.
×
2404
                for _, chanInfo := range chansInfo {
×
2405
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2406
                        infoLookup[scid] = chanInfo
×
2407
                }
×
2408
        })
2409
        if err != nil {
×
2410
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2411
        }
×
2412

2413
        return newChanIDs, knownZombies, nil
×
2414
}
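
A minimal usage sketch follows (not part of the original source): it shows how a gossip syncer might feed a peer's advertised SCIDs into FilterKnownChanIDs. The helper name exampleFilterKnownChanIDs and its parameters are hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, lnwire).

// exampleFilterKnownChanIDs is a hypothetical sketch: it wraps a peer's
// advertised SCIDs in ChannelUpdateInfo records and asks the store which of
// them are new to us, ignoring channels already marked as zombies.
func exampleFilterKnownChanIDs(s *SQLStore,
        peerSCIDs []lnwire.ShortChannelID) ([]uint64, error) {

        // Only the ShortChannelID field is needed by FilterKnownChanIDs; the
        // update timestamps are left at their zero values in this sketch.
        chansInfo := make([]ChannelUpdateInfo, 0, len(peerSCIDs))
        for _, scid := range peerSCIDs {
                chansInfo = append(chansInfo, ChannelUpdateInfo{
                        ShortChannelID: scid,
                })
        }

        newSCIDs, zombies, err := s.FilterKnownChanIDs(chansInfo)
        if err != nil {
                return nil, fmt.Errorf("unable to filter chan IDs: %w", err)
        }

        // The zombie set can be used to decide whether the peer should be
        // asked for fresher updates for those channels.
        fmt.Printf("peer advertised %d unknown channels and %d known "+
                "zombies\n", len(newSCIDs), len(zombies))

        return newSCIDs, nil
}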
2415

2416
// forEachChanInSCIDList is a helper method that executes a paged query
2417
// against the database to fetch all channels that match the passed
2418
// ChannelUpdateInfo slice. The callback function is called for each channel
2419
// that is found.
2420
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2421
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2422
        chansInfo []ChannelUpdateInfo) error {
×
2423

×
2424
        queryWrapper := func(ctx context.Context,
×
2425
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2426

×
2427
                return db.GetChannelsBySCIDs(
×
2428
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2429
                                Version: int16(ProtocolV1),
×
2430
                                Scids:   scids,
×
2431
                        },
×
2432
                )
×
2433
        }
×
2434

2435
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2436
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2437

×
2438
                return channelIDToBytes(channelID)
×
2439
        }
×
2440

2441
        return sqldb.ExecuteBatchQuery(
×
2442
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2443
                cb,
×
2444
        )
×
2445
}
2446

2447
// PruneGraphNodes is a garbage collection method which attempts to prune out
2448
// any nodes from the channel graph that are currently unconnected. This ensures
2449
// that we only maintain a graph of reachable nodes. In the event that a pruned
2450
// node gains more channels, it will be re-added to the graph.
2451
//
2452
// NOTE: this prunes nodes across protocol versions. It will never prune the
2453
// source nodes.
2454
//
2455
// NOTE: part of the V1Store interface.
2456
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2457
        var ctx = context.TODO()
×
2458

×
2459
        var prunedNodes []route.Vertex
×
2460
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2461
                var err error
×
2462
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2463

×
2464
                return err
×
2465
        }, func() {
×
2466
                prunedNodes = nil
×
2467
        })
×
2468
        if err != nil {
×
2469
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2470
        }
×
2471

2472
        return prunedNodes, nil
×
2473
}
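
A minimal usage sketch (not part of the original source): the helper name examplePruneGraphNodes is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt).

// examplePruneGraphNodes is a hypothetical sketch showing a periodic
// garbage-collection pass that removes nodes without channels and reports
// the public keys of everything that was pruned.
func examplePruneGraphNodes(s *SQLStore) error {
        prunedNodes, err := s.PruneGraphNodes()
        if err != nil {
                return fmt.Errorf("unable to prune graph nodes: %w", err)
        }

        for _, pub := range prunedNodes {
                fmt.Printf("pruned unconnected node %x\n", pub[:])
        }

        return nil
}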
2474

2475
// PruneGraph prunes newly closed channels from the channel graph in response
2476
// to a new block being solved on the network. Any transactions which spend the
2477
// funding output of any known channels within the graph will be deleted.
2478
// Additionally, the "prune tip", or the last block which has been used to
2479
// prune the graph, is stored so callers can ensure the graph is fully in sync
2480
// with the current UTXO state. A slice of channels that have been closed by
2481
// the target block along with any pruned nodes are returned if the function
2482
// succeeds without error.
2483
//
2484
// NOTE: part of the V1Store interface.
2485
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2486
        blockHash *chainhash.Hash, blockHeight uint32) (
2487
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2488

×
2489
        ctx := context.TODO()
×
2490

×
2491
        s.cacheMu.Lock()
×
2492
        defer s.cacheMu.Unlock()
×
2493

×
2494
        var (
×
2495
                closedChans []*models.ChannelEdgeInfo
×
2496
                prunedNodes []route.Vertex
×
2497
        )
×
2498
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2499
                // First, collect all channel rows that need to be pruned.
×
2500
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2501
                channelCallback := func(ctx context.Context,
×
2502
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2503

×
2504
                        channelRows = append(channelRows, row)
×
2505

×
2506
                        return nil
×
2507
                }
×
2508

2509
                err := s.forEachChanInOutpoints(
×
2510
                        ctx, db, spentOutputs, channelCallback,
×
2511
                )
×
2512
                if err != nil {
×
2513
                        return fmt.Errorf("unable to fetch channels by "+
×
2514
                                "outpoints: %w", err)
×
2515
                }
×
2516

2517
                if len(channelRows) == 0 {
×
2518
                        // There are no channels to prune. So we can exit early
×
2519
                        // after updating the prune log.
×
2520
                        err = db.UpsertPruneLogEntry(
×
2521
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2522
                                        BlockHash:   blockHash[:],
×
2523
                                        BlockHeight: int64(blockHeight),
×
2524
                                },
×
2525
                        )
×
2526
                        if err != nil {
×
2527
                                return fmt.Errorf("unable to insert prune log "+
×
2528
                                        "entry: %w", err)
×
2529
                        }
×
2530

2531
                        return nil
×
2532
                }
2533

2534
                // Batch build all channel edges for pruning.
2535
                var chansToDelete []int64
×
2536
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2537
                        ctx, s.cfg, db, channelRows,
×
2538
                )
×
2539
                if err != nil {
×
2540
                        return err
×
2541
                }
×
2542

2543
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2544
                if err != nil {
×
2545
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2546
                }
×
2547

2548
                err = db.UpsertPruneLogEntry(
×
2549
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2550
                                BlockHash:   blockHash[:],
×
2551
                                BlockHeight: int64(blockHeight),
×
2552
                        },
×
2553
                )
×
2554
                if err != nil {
×
2555
                        return fmt.Errorf("unable to insert prune log "+
×
2556
                                "entry: %w", err)
×
2557
                }
×
2558

2559
                // Now that we've pruned some channels, we'll also prune any
2560
                // nodes that no longer have any channels.
2561
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2562
                if err != nil {
×
2563
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2564
                                err)
×
2565
                }
×
2566

2567
                return nil
×
2568
        }, func() {
×
2569
                prunedNodes = nil
×
2570
                closedChans = nil
×
2571
        })
×
2572
        if err != nil {
×
2573
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2574
        }
×
2575

2576
        for _, channel := range closedChans {
×
2577
                s.rejectCache.remove(channel.ChannelID)
×
2578
                s.chanCache.remove(channel.ChannelID)
×
2579
        }
×
2580

2581
        return closedChans, prunedNodes, nil
×
2582
}
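
A minimal usage sketch (not part of the original source): the helper name examplePruneGraph and its parameters are hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, wire, chainhash).

// examplePruneGraph is a hypothetical sketch showing how a chain-watching
// caller might prune the graph once a block has been processed: channels
// whose funding outpoints were spent in the block are removed, and the block
// becomes the new prune tip.
func examplePruneGraph(s *SQLStore, spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) error {

        closedChans, prunedNodes, err := s.PruneGraph(
                spentOutputs, blockHash, blockHeight,
        )
        if err != nil {
                return fmt.Errorf("unable to prune graph at height %d: %w",
                        blockHeight, err)
        }

        fmt.Printf("height %d: removed %d channels and %d unconnected "+
                "nodes\n", blockHeight, len(closedChans), len(prunedNodes))

        return nil
}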
2583

2584
// forEachChanInOutpoints is a helper function that executes a paginated
2585
// query to fetch channels by their outpoints and applies the given call-back
2586
// to each.
2587
//
2588
// NOTE: this fetches channels for all protocol versions.
2589
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2590
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2591
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2592

×
2593
        // Create a wrapper that uses the transaction's db instance to execute
×
2594
        // the query.
×
2595
        queryWrapper := func(ctx context.Context,
×
2596
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2597
                error) {
×
2598

×
2599
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2600
        }
×
2601

2602
        // Define the conversion function from Outpoint to string.
2603
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2604
                return outpoint.String()
×
2605
        }
×
2606

2607
        return sqldb.ExecuteBatchQuery(
×
2608
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2609
                queryWrapper, cb,
×
2610
        )
×
2611
}
2612

2613
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2614
        dbIDs []int64) error {
×
2615

×
2616
        // Create a wrapper that uses the transaction's db instance to execute
×
2617
        // the query.
×
2618
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2619
                return nil, db.DeleteChannels(ctx, ids)
×
2620
        }
×
2621

2622
        idConverter := func(id int64) int64 {
×
2623
                return id
×
2624
        }
×
2625

2626
        return sqldb.ExecuteBatchQuery(
×
2627
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2628
                queryWrapper, func(ctx context.Context, _ any) error {
×
2629
                        return nil
×
2630
                },
×
2631
        )
2632
}
2633

2634
// ChannelView returns the verifiable edge information for each active channel
2635
// within the known channel graph. The set of UTXOs (along with their scripts)
2636
// returned are the ones that need to be watched on chain to detect channel
2637
// closes on the resident blockchain.
2638
//
2639
// NOTE: part of the V1Store interface.
2640
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2641
        var (
×
2642
                ctx        = context.TODO()
×
2643
                edgePoints []EdgePoint
×
2644
        )
×
2645

×
2646
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2647
                handleChannel := func(_ context.Context,
×
2648
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2649

×
2650
                        pkScript, err := genMultiSigP2WSH(
×
2651
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2652
                        )
×
2653
                        if err != nil {
×
2654
                                return err
×
2655
                        }
×
2656

2657
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2658
                        if err != nil {
×
2659
                                return err
×
2660
                        }
×
2661

2662
                        edgePoints = append(edgePoints, EdgePoint{
×
2663
                                FundingPkScript: pkScript,
×
2664
                                OutPoint:        *op,
×
2665
                        })
×
2666

×
2667
                        return nil
×
2668
                }
2669

2670
                queryFunc := func(ctx context.Context, lastID int64,
×
2671
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2672

×
2673
                        return db.ListChannelsPaginated(
×
2674
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2675
                                        Version: int16(ProtocolV1),
×
2676
                                        ID:      lastID,
×
2677
                                        Limit:   limit,
×
2678
                                },
×
2679
                        )
×
2680
                }
×
2681

2682
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2683
                        return row.ID
×
2684
                }
×
2685

2686
                return sqldb.ExecutePaginatedQuery(
×
2687
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2688
                        extractCursor, handleChannel,
×
2689
                )
×
2690
        }, func() {
×
2691
                edgePoints = nil
×
2692
        })
×
2693
        if err != nil {
×
2694
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2695
        }
×
2696

2697
        return edgePoints, nil
×
2698
}
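
A minimal usage sketch (not part of the original source): the helper name exampleChannelView is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, wire).

// exampleChannelView is a hypothetical sketch that collects the funding
// outpoints and pkScripts a chain watcher would register in order to detect
// channel closes on-chain.
func exampleChannelView(s *SQLStore) (map[wire.OutPoint][]byte, error) {
        edgePoints, err := s.ChannelView()
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        watchSet := make(map[wire.OutPoint][]byte, len(edgePoints))
        for _, ep := range edgePoints {
                watchSet[ep.OutPoint] = ep.FundingPkScript
        }

        return watchSet, nil
}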
2699

2700
// PruneTip returns the block height and hash of the latest block that has been
2701
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2702
// to tell if the graph is currently in sync with the current best known UTXO
2703
// state.
2704
//
2705
// NOTE: part of the V1Store interface.
2706
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2707
        var (
×
2708
                ctx       = context.TODO()
×
2709
                tipHash   chainhash.Hash
×
2710
                tipHeight uint32
×
2711
        )
×
2712
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2713
                pruneTip, err := db.GetPruneTip(ctx)
×
2714
                if errors.Is(err, sql.ErrNoRows) {
×
2715
                        return ErrGraphNeverPruned
×
2716
                } else if err != nil {
×
2717
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2718
                }
×
2719

2720
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2721
                tipHeight = uint32(pruneTip.BlockHeight)
×
2722

×
2723
                return nil
×
2724
        }, sqldb.NoOpReset)
2725
        if err != nil {
×
2726
                return nil, 0, err
×
2727
        }
×
2728

2729
        return &tipHash, tipHeight, nil
×
2730
}
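
A minimal usage sketch (not part of the original source): the helper name examplePruneTip is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (errors, fmt).

// examplePruneTip is a hypothetical sketch showing how a caller tells a
// graph that has never been pruned apart from one with a known prune tip
// before resuming chain filtering.
func examplePruneTip(s *SQLStore) error {
        tipHash, tipHeight, err := s.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                fmt.Println("graph has never been pruned")
                return nil

        case err != nil:
                return fmt.Errorf("unable to fetch prune tip: %w", err)
        }

        fmt.Printf("graph pruned up to block %v at height %d\n", tipHash,
                tipHeight)

        return nil
}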
2731

2732
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2733
//
2734
// NOTE: this prunes nodes across protocol versions. It will never prune the
2735
// source nodes.
2736
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2737
        db SQLQueries) ([]route.Vertex, error) {
×
2738

×
2739
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2740
        if err != nil {
×
2741
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2742
                        "nodes: %w", err)
×
2743
        }
×
2744

2745
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2746
        for i, nodeKey := range nodeKeys {
×
2747
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2748
                if err != nil {
×
2749
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2750
                                "from bytes: %w", err)
×
2751
                }
×
2752

2753
                prunedNodes[i] = pub
×
2754
        }
2755

2756
        return prunedNodes, nil
×
2757
}
2758

2759
// DisconnectBlockAtHeight is used to indicate that the block specified
2760
// by the passed height has been disconnected from the main chain. This
2761
// will "rewind" the graph back to the height below, deleting channels
2762
// that are no longer confirmed from the graph. The prune log will be
2763
// set to the last prune height valid for the remaining chain.
2764
// Channels that were removed from the graph as a result of the
2765
// disconnected block are returned.
2766
//
2767
// NOTE: part of the V1Store interface.
2768
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2769
        []*models.ChannelEdgeInfo, error) {
×
2770

×
2771
        ctx := context.TODO()
×
2772

×
2773
        var (
×
2774
                // Every channel having a ShortChannelID starting at 'height'
×
2775
                // will no longer be confirmed.
×
2776
                startShortChanID = lnwire.ShortChannelID{
×
2777
                        BlockHeight: height,
×
2778
                }
×
2779

×
2780
                // Delete everything after this height from the db up until the
×
2781
                // SCID alias range.
×
2782
                endShortChanID = aliasmgr.StartingAlias
×
2783

×
2784
                removedChans []*models.ChannelEdgeInfo
×
2785

×
2786
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2787
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2788
        )
×
2789

×
2790
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2791
                rows, err := db.GetChannelsBySCIDRange(
×
2792
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2793
                                StartScid: chanIDStart,
×
2794
                                EndScid:   chanIDEnd,
×
2795
                        },
×
2796
                )
×
2797
                if err != nil {
×
2798
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2799
                }
×
2800

2801
                if len(rows) == 0 {
×
2802
                        // No channels to disconnect, but still clean up prune
×
2803
                        // log.
×
2804
                        return db.DeletePruneLogEntriesInRange(
×
2805
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2806
                                        StartHeight: int64(height),
×
2807
                                        EndHeight: int64(
×
2808
                                                endShortChanID.BlockHeight,
×
2809
                                        ),
×
2810
                                },
×
2811
                        )
×
2812
                }
×
2813

2814
                // Batch build all channel edges for disconnection.
2815
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2816
                        ctx, s.cfg, db, rows,
×
2817
                )
×
2818
                if err != nil {
×
2819
                        return err
×
2820
                }
×
2821

2822
                removedChans = channelEdges
×
2823

×
2824
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2825
                if err != nil {
×
2826
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2827
                }
×
2828

2829
                return db.DeletePruneLogEntriesInRange(
×
2830
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2831
                                StartHeight: int64(height),
×
2832
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2833
                        },
×
2834
                )
×
2835
        }, func() {
×
2836
                removedChans = nil
×
2837
        })
×
2838
        if err != nil {
×
2839
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2840
                        "height: %w", err)
×
2841
        }
×
2842

2843
        for _, channel := range removedChans {
×
2844
                s.rejectCache.remove(channel.ChannelID)
×
2845
                s.chanCache.remove(channel.ChannelID)
×
2846
        }
×
2847

2848
        return removedChans, nil
×
2849
}
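
A minimal usage sketch (not part of the original source): the helper name exampleDisconnectBlock is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt).

// exampleDisconnectBlock is a hypothetical sketch showing how a caller
// handling a chain reorg rewinds the graph so that channels confirmed at or
// above the disconnected height are forgotten.
func exampleDisconnectBlock(s *SQLStore, height uint32) error {
        removedChans, err := s.DisconnectBlockAtHeight(height)
        if err != nil {
                return fmt.Errorf("unable to rewind graph to height %d: %w",
                        height, err)
        }

        for _, edge := range removedChans {
                fmt.Printf("channel %d is no longer confirmed\n",
                        edge.ChannelID)
        }

        return nil
}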
2850

2851
// AddEdgeProof sets the proof of an existing edge in the graph database.
2852
//
2853
// NOTE: part of the V1Store interface.
2854
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2855
        proof *models.ChannelAuthProof) error {
×
2856

×
2857
        var (
×
2858
                ctx       = context.TODO()
×
2859
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2860
        )
×
2861

×
2862
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2863
                res, err := db.AddV1ChannelProof(
×
2864
                        ctx, sqlc.AddV1ChannelProofParams{
×
2865
                                Scid:              scidBytes,
×
2866
                                Node1Signature:    proof.NodeSig1Bytes,
×
2867
                                Node2Signature:    proof.NodeSig2Bytes,
×
2868
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2869
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2870
                        },
×
2871
                )
×
2872
                if err != nil {
×
2873
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2874
                }
×
2875

2876
                n, err := res.RowsAffected()
×
2877
                if err != nil {
×
2878
                        return err
×
2879
                }
×
2880

2881
                if n == 0 {
×
2882
                        return fmt.Errorf("no rows affected when adding edge "+
×
2883
                                "proof for SCID %v", scid)
×
2884
                } else if n > 1 {
×
2885
                        return fmt.Errorf("multiple rows affected when adding "+
×
2886
                                "edge proof for SCID %v: %d rows affected",
×
2887
                                scid, n)
×
2888
                }
×
2889

2890
                return nil
×
2891
        }, sqldb.NoOpReset)
2892
        if err != nil {
×
2893
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2894
        }
×
2895

2896
        return nil
×
2897
}
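
A minimal usage sketch (not part of the original source): the helper name exampleAddEdgeProof and its signature parameters are hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, lnwire, models).

// exampleAddEdgeProof is a hypothetical sketch showing how the four
// announcement signatures for an already-stored channel are persisted once
// the full channel proof has been assembled.
func exampleAddEdgeProof(s *SQLStore, scid lnwire.ShortChannelID,
        nodeSig1, nodeSig2, btcSig1, btcSig2 []byte) error {

        proof := &models.ChannelAuthProof{
                NodeSig1Bytes:    nodeSig1,
                NodeSig2Bytes:    nodeSig2,
                BitcoinSig1Bytes: btcSig1,
                BitcoinSig2Bytes: btcSig2,
        }

        if err := s.AddEdgeProof(scid, proof); err != nil {
                return fmt.Errorf("unable to add proof for %v: %w", scid, err)
        }

        return nil
}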
2898

2899
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2900
// that we can ignore announcements for channels that we know to be closed
2901
// without having to validate them and fetch a block.
2902
//
2903
// NOTE: part of the V1Store interface.
2904
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2905
        var (
×
2906
                ctx     = context.TODO()
×
2907
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2908
        )
×
2909

×
2910
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2911
                return db.InsertClosedChannel(ctx, chanIDB)
×
2912
        }, sqldb.NoOpReset)
×
2913
}
2914

2915
// IsClosedScid checks whether a channel identified by the passed in scid is
2916
// closed. This helps avoid having to perform expensive validation checks.
2917
//
2918
// NOTE: part of the V1Store interface.
2919
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2920
        var (
×
2921
                ctx      = context.TODO()
×
2922
                isClosed bool
×
2923
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2924
        )
×
2925
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2926
                var err error
×
2927
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2928
                if err != nil {
×
2929
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2930
                                err)
×
2931
                }
×
2932

2933
                return nil
×
2934
        }, sqldb.NoOpReset)
2935
        if err != nil {
×
2936
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2937
                        err)
×
2938
        }
×
2939

2940
        return isClosed, nil
×
2941
}
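
A minimal usage sketch (not part of the original source): the helper name exampleClosedSCID is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, lnwire).

// exampleClosedSCID is a hypothetical sketch showing the intended pairing of
// PutClosedScid and IsClosedScid: once an SCID is recorded as closed, later
// announcements for it can be skipped without any on-chain validation.
func exampleClosedSCID(s *SQLStore, scid lnwire.ShortChannelID) (bool, error) {
        if err := s.PutClosedScid(scid); err != nil {
                return false, fmt.Errorf("unable to mark scid as closed: %w",
                        err)
        }

        closed, err := s.IsClosedScid(scid)
        if err != nil {
                return false, fmt.Errorf("unable to query closed scid: %w",
                        err)
        }

        return closed, nil
}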
2942

2943
// GraphSession will provide the call-back with access to a NodeTraverser
2944
// instance which can be used to perform queries against the channel graph.
2945
//
2946
// NOTE: part of the V1Store interface.
2947
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2948
        reset func()) error {
×
2949

×
2950
        var ctx = context.TODO()
×
2951

×
2952
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2953
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2954
        }, reset)
×
2955
}
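
A minimal usage sketch (not part of the original source): the helper name exampleGraphSession is hypothetical, and it assumes an initialized *SQLStore plus this package's existing imports (fmt, route).

// exampleGraphSession is a hypothetical sketch showing how pathfinding-style
// code can run multiple traversals against one consistent read transaction
// via GraphSession.
func exampleGraphSession(s *SQLStore, source route.Vertex) error {
        var numChannels int

        err := s.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        source, func(_ *DirectedChannel) error {
                                numChannels++

                                return nil
                        }, func() {
                                numChannels = 0
                        },
                )
        }, func() {
                // Reset partial state in case the transaction is retried.
                numChannels = 0
        })
        if err != nil {
                return fmt.Errorf("graph session failed: %w", err)
        }

        fmt.Printf("source node has %d channels\n", numChannels)

        return nil
}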
2956

2957
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2958
// read only transaction for a consistent view of the graph.
2959
type sqlNodeTraverser struct {
2960
        db    SQLQueries
2961
        chain chainhash.Hash
2962
}
2963

2964
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2965
// NodeTraverser interface.
2966
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2967

2968
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2969
func newSQLNodeTraverser(db SQLQueries,
2970
        chain chainhash.Hash) *sqlNodeTraverser {
×
2971

×
2972
        return &sqlNodeTraverser{
×
2973
                db:    db,
×
2974
                chain: chain,
×
2975
        }
×
2976
}
×
2977

2978
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2979
// node.
2980
//
2981
// NOTE: Part of the NodeTraverser interface.
2982
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2983
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2984

×
2985
        ctx := context.TODO()
×
2986

×
2987
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2988
}
×
2989

2990
// FetchNodeFeatures returns the features of the given node. If the node is
2991
// unknown, assume no additional features are supported.
2992
//
2993
// NOTE: Part of the NodeTraverser interface.
2994
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2995
        *lnwire.FeatureVector, error) {
×
2996

×
2997
        ctx := context.TODO()
×
2998

×
2999
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3000
}
×
3001

3002
// forEachNodeDirectedChannel iterates through all channels of a given
3003
// node, executing the passed callback on the directed edge representing the
3004
// channel and its incoming policy. If the node is not found, no error is
3005
// returned.
3006
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3007
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3008

×
3009
        toNodeCallback := func() route.Vertex {
×
3010
                return nodePub
×
3011
        }
×
3012

3013
        dbID, err := db.GetNodeIDByPubKey(
×
3014
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3015
                        Version: int16(ProtocolV1),
×
3016
                        PubKey:  nodePub[:],
×
3017
                },
×
3018
        )
×
3019
        if errors.Is(err, sql.ErrNoRows) {
×
3020
                return nil
×
3021
        } else if err != nil {
×
3022
                return fmt.Errorf("unable to fetch node: %w", err)
×
3023
        }
×
3024

3025
        rows, err := db.ListChannelsByNodeID(
×
3026
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3027
                        Version: int16(ProtocolV1),
×
3028
                        NodeID1: dbID,
×
3029
                },
×
3030
        )
×
3031
        if err != nil {
×
3032
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3033
        }
×
3034

3035
        // Exit early if there are no channels for this node so we don't
3036
        // do the unnecessary feature fetching.
3037
        if len(rows) == 0 {
×
3038
                return nil
×
3039
        }
×
3040

3041
        features, err := getNodeFeatures(ctx, db, dbID)
×
3042
        if err != nil {
×
3043
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3044
        }
×
3045

3046
        for _, row := range rows {
×
3047
                node1, node2, err := buildNodeVertices(
×
3048
                        row.Node1Pubkey, row.Node2Pubkey,
×
3049
                )
×
3050
                if err != nil {
×
3051
                        return fmt.Errorf("unable to build node vertices: %w",
×
3052
                                err)
×
3053
                }
×
3054

3055
                edge := buildCacheableChannelInfo(
×
3056
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3057
                        node1, node2,
×
3058
                )
×
3059

×
3060
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3061
                if err != nil {
×
3062
                        return err
×
3063
                }
×
3064

3065
                p1, p2, err := buildCachedChanPolicies(
×
3066
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3067
                )
×
3068
                if err != nil {
×
3069
                        return err
×
3070
                }
×
3071

3072
                // Determine the outgoing and incoming policy for this
3073
                // channel and node combo.
3074
                outPolicy, inPolicy := p1, p2
×
3075
                if p1 != nil && node2 == nodePub {
×
3076
                        outPolicy, inPolicy = p2, p1
×
3077
                } else if p2 != nil && node1 != nodePub {
×
3078
                        outPolicy, inPolicy = p2, p1
×
3079
                }
×
3080

3081
                var cachedInPolicy *models.CachedEdgePolicy
×
3082
                if inPolicy != nil {
×
3083
                        cachedInPolicy = inPolicy
×
3084
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3085
                        cachedInPolicy.ToNodeFeatures = features
×
3086
                }
×
3087

3088
                directedChannel := &DirectedChannel{
×
3089
                        ChannelID:    edge.ChannelID,
×
3090
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3091
                        OtherNode:    edge.NodeKey2Bytes,
×
3092
                        Capacity:     edge.Capacity,
×
3093
                        OutPolicySet: outPolicy != nil,
×
3094
                        InPolicy:     cachedInPolicy,
×
3095
                }
×
3096
                if outPolicy != nil {
×
3097
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3098
                                directedChannel.InboundFee = fee
×
3099
                        })
×
3100
                }
3101

3102
                if nodePub == edge.NodeKey2Bytes {
×
3103
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3104
                }
×
3105

3106
                if err := cb(directedChannel); err != nil {
×
3107
                        return err
×
3108
                }
×
3109
        }
3110

3111
        return nil
×
3112
}
3113

3114
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3115
// and executes the provided callback for each node. It does so via pagination
3116
// along with batch loading of the node feature bits.
3117
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3118
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3119
                features *lnwire.FeatureVector) error) error {
×
3120

×
3121
        handleNode := func(_ context.Context,
×
3122
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3123
                featureBits map[int64][]int) error {
×
3124

×
3125
                fv := lnwire.EmptyFeatureVector()
×
3126
                if features, exists := featureBits[dbNode.ID]; exists {
×
3127
                        for _, bit := range features {
×
3128
                                fv.Set(lnwire.FeatureBit(bit))
×
3129
                        }
×
3130
                }
3131

3132
                var pub route.Vertex
×
3133
                copy(pub[:], dbNode.PubKey)
×
3134

×
3135
                return processNode(dbNode.ID, pub, fv)
×
3136
        }
3137

3138
        queryFunc := func(ctx context.Context, lastID int64,
×
3139
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3140

×
3141
                return db.ListNodeIDsAndPubKeys(
×
3142
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3143
                                Version: int16(ProtocolV1),
×
3144
                                ID:      lastID,
×
3145
                                Limit:   limit,
×
3146
                        },
×
3147
                )
×
3148
        }
×
3149

3150
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3151
                return row.ID
×
3152
        }
×
3153

3154
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3155
                return node.ID, nil
×
3156
        }
×
3157

3158
        batchQueryFunc := func(ctx context.Context,
×
3159
                nodeIDs []int64) (map[int64][]int, error) {
×
3160

×
3161
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3162
        }
×
3163

3164
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3165
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3166
                batchQueryFunc, handleNode,
×
3167
        )
×
3168
}
3169

3170
// forEachNodeChannel iterates through all channels of a node, executing
3171
// the passed callback on each. The call-back is provided with the channel's
3172
// edge information, the outgoing policy and the incoming policy for the
3173
// channel and node combo.
3174
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3175
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3176
                *models.ChannelEdgePolicy,
3177
                *models.ChannelEdgePolicy) error) error {
×
3178

×
3179
        // Get all the V1 channels for this node.
×
3180
        rows, err := db.ListChannelsByNodeID(
×
3181
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3182
                        Version: int16(ProtocolV1),
×
3183
                        NodeID1: id,
×
3184
                },
×
3185
        )
×
3186
        if err != nil {
×
3187
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3188
        }
×
3189

3190
        // Collect all the channel and policy IDs.
3191
        var (
×
3192
                chanIDs   = make([]int64, 0, len(rows))
×
3193
                policyIDs = make([]int64, 0, 2*len(rows))
×
3194
        )
×
3195
        for _, row := range rows {
×
3196
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3197

×
3198
                if row.Policy1ID.Valid {
×
3199
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3200
                }
×
3201
                if row.Policy2ID.Valid {
×
3202
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3203
                }
×
3204
        }
3205

3206
        batchData, err := batchLoadChannelData(
×
3207
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3208
        )
×
3209
        if err != nil {
×
3210
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3211
        }
×
3212

3213
        // Call the call-back for each channel and its known policies.
3214
        for _, row := range rows {
×
3215
                node1, node2, err := buildNodeVertices(
×
3216
                        row.Node1Pubkey, row.Node2Pubkey,
×
3217
                )
×
3218
                if err != nil {
×
3219
                        return fmt.Errorf("unable to build node vertices: %w",
×
3220
                                err)
×
3221
                }
×
3222

3223
                edge, err := buildEdgeInfoWithBatchData(
×
3224
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3225
                        batchData,
×
3226
                )
×
3227
                if err != nil {
×
3228
                        return fmt.Errorf("unable to build channel info: %w",
×
3229
                                err)
×
3230
                }
×
3231

3232
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3233
                if err != nil {
×
3234
                        return fmt.Errorf("unable to extract channel "+
×
3235
                                "policies: %w", err)
×
3236
                }
×
3237

3238
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3239
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3240
                )
×
3241
                if err != nil {
×
3242
                        return fmt.Errorf("unable to build channel "+
×
3243
                                "policies: %w", err)
×
3244
                }
×
3245

3246
                // Determine the outgoing and incoming policy for this
3247
                // channel and node combo.
3248
                p1ToNode := row.GraphChannel.NodeID2
×
3249
                p2ToNode := row.GraphChannel.NodeID1
×
3250
                outPolicy, inPolicy := p1, p2
×
3251
                if (p1 != nil && p1ToNode == id) ||
×
3252
                        (p2 != nil && p2ToNode != id) {
×
3253

×
3254
                        outPolicy, inPolicy = p2, p1
×
3255
                }
×
3256

3257
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3258
                        return err
×
3259
                }
×
3260
        }
3261

3262
        return nil
×
3263
}
3264

3265
// updateChanEdgePolicy upserts the channel policy info we have stored for
3266
// a channel we already know of.
3267
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3268
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3269
        error) {
×
3270

×
3271
        var (
×
3272
                node1Pub, node2Pub route.Vertex
×
3273
                isNode1            bool
×
3274
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3275
        )
×
3276

×
3277
        // Check that this edge policy refers to a channel that we already
×
3278
        // know of. We do this explicitly so that we can return the appropriate
×
3279
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3280
        // abort the transaction which would abort the entire batch.
×
3281
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3282
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3283
                        Scid:    chanIDB,
×
3284
                        Version: int16(ProtocolV1),
×
3285
                },
×
3286
        )
×
3287
        if errors.Is(err, sql.ErrNoRows) {
×
3288
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3289
        } else if err != nil {
×
3290
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3291
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3292
        }
×
3293

3294
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3295
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3296

×
3297
        // Figure out which node this edge is from.
×
3298
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3299
        nodeID := dbChan.NodeID1
×
3300
        if !isNode1 {
×
3301
                nodeID = dbChan.NodeID2
×
3302
        }
×
3303

3304
        var (
×
3305
                inboundBase sql.NullInt64
×
3306
                inboundRate sql.NullInt64
×
3307
        )
×
3308
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3309
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3310
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3311
        })
×
3312

3313
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3314
                Version:     int16(ProtocolV1),
×
3315
                ChannelID:   dbChan.ID,
×
3316
                NodeID:      nodeID,
×
3317
                Timelock:    int32(edge.TimeLockDelta),
×
3318
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3319
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3320
                MinHtlcMsat: int64(edge.MinHTLC),
×
3321
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3322
                Disabled: sql.NullBool{
×
3323
                        Valid: true,
×
3324
                        Bool:  edge.IsDisabled(),
×
3325
                },
×
3326
                MaxHtlcMsat: sql.NullInt64{
×
3327
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3328
                        Int64: int64(edge.MaxHTLC),
×
3329
                },
×
3330
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3331
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3332
                InboundBaseFeeMsat:      inboundBase,
×
3333
                InboundFeeRateMilliMsat: inboundRate,
×
3334
                Signature:               edge.SigBytes,
×
3335
        })
×
3336
        if err != nil {
×
3337
                return node1Pub, node2Pub, isNode1,
×
3338
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3339
        }
×
3340

3341
        // Convert the flat extra opaque data into a map of TLV types to
3342
        // values.
3343
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3344
        if err != nil {
×
3345
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3346
                        "marshal extra opaque data: %w", err)
×
3347
        }
×
3348

3349
        // Update the channel policy's extra signed fields.
3350
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3351
        if err != nil {
×
3352
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3353
                        "policy extra TLVs: %w", err)
×
3354
        }
×
3355

3356
        return node1Pub, node2Pub, isNode1, nil
×
3357
}
3358

3359
// getNodeByPubKey attempts to look up a target node by its public key.
3360
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3361
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3362

×
3363
        dbNode, err := db.GetNodeByPubKey(
×
3364
                ctx, sqlc.GetNodeByPubKeyParams{
×
3365
                        Version: int16(ProtocolV1),
×
3366
                        PubKey:  pubKey[:],
×
3367
                },
×
3368
        )
×
3369
        if errors.Is(err, sql.ErrNoRows) {
×
3370
                return 0, nil, ErrGraphNodeNotFound
×
3371
        } else if err != nil {
×
3372
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3373
        }
×
3374

3375
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3376
        if err != nil {
×
3377
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3378
        }
×
3379

3380
        return dbNode.ID, node, nil
×
3381
}
3382

3383
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3384
// provided parameters.
3385
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3386
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3387

×
3388
        return &models.CachedEdgeInfo{
×
3389
                ChannelID:     byteOrder.Uint64(scid),
×
3390
                NodeKey1Bytes: node1Pub,
×
3391
                NodeKey2Bytes: node2Pub,
×
3392
                Capacity:      btcutil.Amount(capacity),
×
3393
        }
×
3394
}
×
3395

3396
// buildNode constructs a Node instance from the given database node
3397
// record. The node's features, addresses and extra signed fields are also
3398
// fetched from the database and set on the node.
3399
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3400
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3401

×
3402
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3403
        if err != nil {
×
3404
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3405
                        err)
×
3406
        }
×
3407

3408
        return buildNodeWithBatchData(dbNode, data)
×
3409
}
3410

3411
// buildNodeWithBatchData builds a models.Node instance
3412
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3413
// features/addresses/extra fields, then the corresponding fields are expected
3414
// to be present in the batchNodeData.
3415
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3416
        batchData *batchNodeData) (*models.Node, error) {
×
3417

×
3418
        if dbNode.Version != int16(ProtocolV1) {
×
3419
                return nil, fmt.Errorf("unsupported node version: %d",
×
3420
                        dbNode.Version)
×
3421
        }
×
3422

3423
        var pub [33]byte
×
3424
        copy(pub[:], dbNode.PubKey)
×
3425

×
3426
        node := &models.Node{
×
3427
                PubKeyBytes: pub,
×
3428
                Features:    lnwire.EmptyFeatureVector(),
×
3429
                LastUpdate:  time.Unix(0, 0),
×
3430
        }
×
3431

×
3432
        if len(dbNode.Signature) == 0 {
×
3433
                return node, nil
×
3434
        }
×
3435

3436
        node.HaveNodeAnnouncement = true
×
3437
        node.AuthSigBytes = dbNode.Signature
×
3438
        node.Alias = dbNode.Alias.String
×
3439
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3440

×
3441
        var err error
×
3442
        if dbNode.Color.Valid {
×
3443
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3444
                if err != nil {
×
3445
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3446
                                err)
×
3447
                }
×
3448
        }
3449

3450
        // Use preloaded features.
3451
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3452
                fv := lnwire.EmptyFeatureVector()
×
3453
                for _, bit := range features {
×
3454
                        fv.Set(lnwire.FeatureBit(bit))
×
3455
                }
×
3456
                node.Features = fv
×
3457
        }
3458

3459
        // Use preloaded addresses.
3460
        addresses, exists := batchData.addresses[dbNode.ID]
×
3461
        if exists && len(addresses) > 0 {
×
3462
                node.Addresses, err = buildNodeAddresses(addresses)
×
3463
                if err != nil {
×
3464
                        return nil, fmt.Errorf("unable to build addresses "+
×
3465
                                "for node(%d): %w", dbNode.ID, err)
×
3466
                }
×
3467
        }
3468

3469
        // Use preloaded extra fields.
3470
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3471
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3472
                if err != nil {
×
3473
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3474
                                "signed fields: %w", err)
×
3475
                }
×
3476
                if len(recs) != 0 {
×
3477
                        node.ExtraOpaqueData = recs
×
3478
                }
×
3479
        }
3480

3481
        return node, nil
×
3482
}
3483

3484
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3485
// with the preloaded data, and executes the provided callback for each node.
3486
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3487
        db SQLQueries, nodes []sqlc.GraphNode,
3488
        cb func(dbID int64, node *models.Node) error) error {
×
3489

×
3490
        // Extract node IDs for batch loading.
×
3491
        nodeIDs := make([]int64, len(nodes))
×
3492
        for i, node := range nodes {
×
3493
                nodeIDs[i] = node.ID
×
3494
        }
×
3495

3496
        // Batch load all related data for this page.
3497
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3498
        if err != nil {
×
3499
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3500
        }
×
3501

3502
        for _, dbNode := range nodes {
×
3503
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3504
                if err != nil {
×
3505
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3506
                                dbNode.ID, err)
×
3507
                }
×
3508

3509
                if err := cb(dbNode.ID, node); err != nil {
×
3510
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3511
                                dbNode.ID, err)
×
3512
                }
×
3513
        }
3514

3515
        return nil
×
3516
}
3517

3518
// getNodeFeatures fetches the feature bits and constructs the feature vector
3519
// for a node with the given DB ID.
3520
func getNodeFeatures(ctx context.Context, db SQLQueries,
3521
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3522

×
3523
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3524
        if err != nil {
×
3525
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3526
                        nodeID, err)
×
3527
        }
×
3528

3529
        features := lnwire.EmptyFeatureVector()
×
3530
        for _, feature := range rows {
×
3531
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3532
        }
×
3533

3534
        return features, nil
×
3535
}
3536

3537
// upsertNode upserts the node record into the database. If the node already
3538
// exists, then the node's information is updated. If the node doesn't exist,
3539
// then a new node is created. The node's features, addresses and extra TLV
3540
// types are also updated. The node's DB ID is returned.
3541
func upsertNode(ctx context.Context, db SQLQueries,
3542
        node *models.Node) (int64, error) {
×
3543

×
3544
        params := sqlc.UpsertNodeParams{
×
3545
                Version: int16(ProtocolV1),
×
3546
                PubKey:  node.PubKeyBytes[:],
×
3547
        }
×
3548

×
3549
        if node.HaveNodeAnnouncement {
×
3550
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3551
                params.Color = sqldb.SQLStrValid(EncodeHexColor(node.Color))
×
3552
                params.Alias = sqldb.SQLStrValid(node.Alias)
×
3553
                params.Signature = node.AuthSigBytes
×
3554
        }
×
3555

3556
        nodeID, err := db.UpsertNode(ctx, params)
×
3557
        if err != nil {
×
3558
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3559
                        err)
×
3560
        }
×
3561

3562
        // We can exit here if we don't have the announcement yet.
3563
        if !node.HaveNodeAnnouncement {
×
3564
                return nodeID, nil
×
3565
        }
×
3566

3567
        // Update the node's features.
3568
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3569
        if err != nil {
×
3570
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3571
        }
×
3572

3573
        // Update the node's addresses.
3574
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3575
        if err != nil {
×
3576
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3577
        }
×
3578

3579
        // Convert the flat extra opaque data into a map of TLV types to
3580
        // values.
3581
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3582
        if err != nil {
×
3583
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3584
                        err)
×
3585
        }
×
3586

3587
        // Update the node's extra signed fields.
3588
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3589
        if err != nil {
×
3590
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3591
        }
×
3592

3593
        return nodeID, nil
×
3594
}
3595

3596
// upsertNodeFeatures updates the node_features table for the given node. This
3597
// includes deleting any feature bits no longer present and inserting any new
3598
// feature bits. If the feature bit does not yet exist in the features table,
3599
// then an entry is created in that table first.
3600
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3601
        features *lnwire.FeatureVector) error {
×
3602

×
3603
        // Get any existing features for the node.
×
3604
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3605
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3606
                return err
×
3607
        }
×
3608

3609
        // Copy the node's latest set of feature bits.
3610
        newFeatures := make(map[int32]struct{})
×
3611
        if features != nil {
×
3612
                for feature := range features.Features() {
×
3613
                        newFeatures[int32(feature)] = struct{}{}
×
3614
                }
×
3615
        }
3616

3617
        // For any current feature that already exists in the DB, remove it from
3618
        // the in-memory map. For any existing feature that does not exist in
3619
        // the in-memory map, delete it from the database.
3620
        for _, feature := range existingFeatures {
×
3621
                // The feature is still present, so there are no updates to be
×
3622
                // made.
×
3623
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3624
                        delete(newFeatures, feature.FeatureBit)
×
3625
                        continue
×
3626
                }
3627

3628
                // The feature is no longer present, so we remove it from the
3629
                // database.
3630
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3631
                        NodeID:     nodeID,
×
3632
                        FeatureBit: feature.FeatureBit,
×
3633
                })
×
3634
                if err != nil {
×
3635
                        return fmt.Errorf("unable to delete node(%d) "+
×
3636
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3637
                                err)
×
3638
                }
×
3639
        }
3640

3641
        // Any remaining entries in newFeatures are new features that need to be
3642
        // added to the database for the first time.
3643
        for feature := range newFeatures {
×
3644
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3645
                        NodeID:     nodeID,
×
3646
                        FeatureBit: feature,
×
3647
                })
×
3648
                if err != nil {
×
3649
                        return fmt.Errorf("unable to insert node(%d) "+
×
3650
                                "feature(%v): %w", nodeID, feature, err)
×
3651
                }
×
3652
        }
3653

3654
        return nil
×
3655
}
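
// What follows is an illustrative sketch only (not used by the store): it
// reduces the reconcile-by-set-difference pattern that upsertNodeFeatures
// applies above to plain feature-bit sets. The helper name is hypothetical.
func diffFeatureBits(existing, updated map[int32]struct{}) (toDelete,
        toInsert []int32) {

        // Anything persisted in the DB but absent from the latest
        // announcement is stale and should be deleted.
        for bit := range existing {
                if _, ok := updated[bit]; !ok {
                        toDelete = append(toDelete, bit)
                }
        }

        // Anything in the latest announcement but not yet persisted must be
        // inserted.
        for bit := range updated {
                if _, ok := existing[bit]; !ok {
                        toInsert = append(toInsert, bit)
                }
        }

        return toDelete, toInsert
}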
3656

3657
// fetchNodeFeatures fetches the features for a node with the given public key.
3658
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3659
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3660

×
3661
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3662
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3663
                        PubKey:  nodePub[:],
×
3664
                        Version: int16(ProtocolV1),
×
3665
                },
×
3666
        )
×
3667
        if err != nil {
×
3668
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3669
                        nodePub, err)
×
3670
        }
×
3671

3672
        features := lnwire.EmptyFeatureVector()
×
3673
        for _, bit := range rows {
×
3674
                features.Set(lnwire.FeatureBit(bit))
×
3675
        }
×
3676

3677
        return features, nil
×
3678
}
3679

3680
// dbAddressType is an enum type that represents the different address types
3681
// that we store in the node_addresses table. The address type determines how
3682
// the address is to be serialised/deserialised.
3683
type dbAddressType uint8
3684

3685
const (
3686
        addressTypeIPv4   dbAddressType = 1
3687
        addressTypeIPv6   dbAddressType = 2
3688
        addressTypeTorV2  dbAddressType = 3
3689
        addressTypeTorV3  dbAddressType = 4
3690
        addressTypeDNS    dbAddressType = 5
3691
        addressTypeOpaque dbAddressType = math.MaxInt8
3692
)
3693

3694
// collectAddressRecords collects the addresses from the provided
3695
// net.Addr slice and returns a map of dbAddressType to a slice of address
3696
// strings.
3697
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3698
        error) {
×
3699

×
3700
        // Group the node's latest set of addresses by address type.
×
3701
        newAddresses := map[dbAddressType][]string{
×
3702
                addressTypeIPv4:   {},
×
3703
                addressTypeIPv6:   {},
×
3704
                addressTypeTorV2:  {},
×
3705
                addressTypeTorV3:  {},
×
3706
                addressTypeDNS:    {},
×
3707
                addressTypeOpaque: {},
×
3708
        }
×
3709
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3710
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3711
        }
×
3712

3713
        for _, address := range addresses {
×
3714
                switch addr := address.(type) {
×
3715
                case *net.TCPAddr:
×
3716
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3717
                                addAddr(addressTypeIPv4, addr)
×
3718
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3719
                                addAddr(addressTypeIPv6, addr)
×
3720
                        } else {
×
3721
                                return nil, fmt.Errorf("unhandled IP "+
×
3722
                                        "address: %v", addr)
×
3723
                        }
×
3724

3725
                case *tor.OnionAddr:
×
3726
                        switch len(addr.OnionService) {
×
3727
                        case tor.V2Len:
×
3728
                                addAddr(addressTypeTorV2, addr)
×
3729
                        case tor.V3Len:
×
3730
                                addAddr(addressTypeTorV3, addr)
×
3731
                        default:
×
3732
                                return nil, fmt.Errorf("invalid length for " +
×
3733
                                        "a tor address")
×
3734
                        }
3735

3736
                case *lnwire.DNSAddress:
×
3737
                        addAddr(addressTypeDNS, addr)
×
3738

3739
                case *lnwire.OpaqueAddrs:
×
3740
                        addAddr(addressTypeOpaque, addr)
×
3741

3742
                default:
×
3743
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3744
                                addr)
×
3745
                }
3746
        }
3747

3748
        return newAddresses, nil
×
3749
}
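
// An illustrative usage sketch of collectAddressRecords (the addresses are
// arbitrary examples; the exact string form comes from each net.Addr's
// String method):
//
//        addrs := []net.Addr{
//                &net.TCPAddr{IP: net.ParseIP("203.0.113.7"), Port: 9735},
//                &net.TCPAddr{IP: net.ParseIP("2001:db8::7"), Port: 9735},
//        }
//        grouped, err := collectAddressRecords(addrs)
//        // On success, grouped[addressTypeIPv4] contains "203.0.113.7:9735"
//        // and grouped[addressTypeIPv6] contains "[2001:db8::7]:9735".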
3750

3751
// upsertNodeAddresses updates the node's addresses in the database. This
3752
// includes deleting any existing addresses and inserting the new set of
3753
// addresses. The deletion is necessary since the ordering of the addresses may
3754
// change, and we need to ensure that the database reflects the latest set of
3755
// addresses so that at the time of reconstructing the node announcement, the
3756
// order is preserved and the signature over the message remains valid.
3757
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3758
        addresses []net.Addr) error {
×
3759

×
3760
        // Delete any existing addresses for the node. This is required since
×
3761
        // even if the new set of addresses is the same, the ordering may have
×
3762
        // changed for a given address type.
×
3763
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3764
        if err != nil {
×
3765
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3766
                        nodeID, err)
×
3767
        }
×
3768

3769
        newAddresses, err := collectAddressRecords(addresses)
×
3770
        if err != nil {
×
3771
                return err
×
3772
        }
×
3773

3774
        // Insert the new set of addresses, preserving each address's position
3775
        // within its type so the original ordering can be reconstructed.
3776
        for addrType, addrList := range newAddresses {
×
3777
                for position, addr := range addrList {
×
3778
                        err := db.UpsertNodeAddress(
×
3779
                                ctx, sqlc.UpsertNodeAddressParams{
×
3780
                                        NodeID:   nodeID,
×
3781
                                        Type:     int16(addrType),
×
3782
                                        Address:  addr,
×
3783
                                        Position: int32(position),
×
3784
                                },
×
3785
                        )
×
3786
                        if err != nil {
×
3787
                                return fmt.Errorf("unable to insert "+
×
3788
                                        "node(%d) address(%v): %w", nodeID,
×
3789
                                        addr, err)
×
3790
                        }
×
3791
                }
3792
        }
3793

3794
        return nil
×
3795
}
3796

3797
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3798
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3799
        error) {
×
3800

×
3801
        // GetNodeAddresses ensures that the addresses for a given type are
×
3802
        // returned in the same order as they were inserted.
×
3803
        rows, err := db.GetNodeAddresses(ctx, id)
×
3804
        if err != nil {
×
3805
                return nil, err
×
3806
        }
×
3807

3808
        addresses := make([]net.Addr, 0, len(rows))
×
3809
        for _, row := range rows {
×
3810
                address := row.Address
×
3811

×
3812
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3813
                if err != nil {
×
3814
                        return nil, fmt.Errorf("unable to parse address "+
×
3815
                                "for node(%d): %v: %w", id, address, err)
×
3816
                }
×
3817

3818
                addresses = append(addresses, addr)
×
3819
        }
3820

3821
        // If we have no addresses, then we'll return nil instead of an
3822
        // empty slice.
3823
        if len(addresses) == 0 {
×
3824
                addresses = nil
×
3825
        }
×
3826

3827
        return addresses, nil
×
3828
}
3829

3830
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3831
// database. This includes updating any existing types, inserting any new types,
3832
// and deleting any types that are no longer present.
3833
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3834
        nodeID int64, extraFields map[uint64][]byte) error {
×
3835

×
3836
        // Get any existing extra signed fields for the node.
×
3837
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3838
        if err != nil {
×
3839
                return err
×
3840
        }
×
3841

3842
        // Make a lookup map of the existing field types so that we can use it
3843
        // to keep track of any fields we should delete.
3844
        m := make(map[uint64]bool)
×
3845
        for _, field := range existingFields {
×
3846
                m[uint64(field.Type)] = true
×
3847
        }
×
3848

3849
        // For all the new fields, we'll upsert them and remove them from the
3850
        // map of existing fields.
3851
        for tlvType, value := range extraFields {
×
3852
                err = db.UpsertNodeExtraType(
×
3853
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3854
                                NodeID: nodeID,
×
3855
                                Type:   int64(tlvType),
×
3856
                                Value:  value,
×
3857
                        },
×
3858
                )
×
3859
                if err != nil {
×
3860
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3861
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3862
                }
×
3863

3864
                // Remove the field from the map of existing fields if it was
3865
                // present.
3866
                delete(m, tlvType)
×
3867
        }
3868

3869
        // For all the fields that are left in the map of existing fields, we'll
3870
        // delete them as they are no longer present in the new set of fields.
3871
        for tlvType := range m {
×
3872
                err = db.DeleteExtraNodeType(
×
3873
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3874
                                NodeID: nodeID,
×
3875
                                Type:   int64(tlvType),
×
3876
                        },
×
3877
                )
×
3878
                if err != nil {
×
3879
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3880
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3881
                }
×
3882
        }
3883

3884
        return nil
×
3885
}
3886

3887
// srcNodeInfo holds the information about the source node of the graph.
3888
type srcNodeInfo struct {
3889
        // id is the DB level ID of the source node entry in the "nodes" table.
3890
        id int64
3891

3892
        // pub is the public key of the source node.
3893
        pub route.Vertex
3894
}
3895

3896
// getSourceNode returns the DB node ID and pub key of the source node for
3897
// the specified protocol version.
3898
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3899
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3900

×
3901
        s.srcNodeMu.Lock()
×
3902
        defer s.srcNodeMu.Unlock()
×
3903

×
3904
        // If we already have the source node ID and pub key cached, then
×
3905
        // return them.
×
3906
        if info, ok := s.srcNodes[version]; ok {
×
3907
                return info.id, info.pub, nil
×
3908
        }
×
3909

3910
        var pubKey route.Vertex
×
3911

×
3912
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3913
        if err != nil {
×
3914
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3915
                        err)
×
3916
        }
×
3917

3918
        if len(nodes) == 0 {
×
3919
                return 0, pubKey, ErrSourceNodeNotSet
×
3920
        } else if len(nodes) > 1 {
×
3921
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3922
                        "protocol %s found", version)
×
3923
        }
×
3924

3925
        copy(pubKey[:], nodes[0].PubKey)
×
3926

×
3927
        s.srcNodes[version] = &srcNodeInfo{
×
3928
                id:  nodes[0].NodeID,
×
3929
                pub: pubKey,
×
3930
        }
×
3931

×
3932
        return nodes[0].NodeID, pubKey, nil
×
3933
}
3934

3935
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3936
// This then produces a map from TLV type to value. If the input is not a
3937
// valid TLV stream, then an error is returned.
3938
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3939
        r := bytes.NewReader(data)
×
3940

×
3941
        tlvStream, err := tlv.NewStream()
×
3942
        if err != nil {
×
3943
                return nil, err
×
3944
        }
×
3945

3946
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3947
        // pass it into the P2P decoding variant.
3948
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3949
        if err != nil {
×
3950
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3951
        }
×
3952
        if len(parsedTypes) == 0 {
×
3953
                return nil, nil
×
3954
        }
×
3955

3956
        records := make(map[uint64][]byte)
×
3957
        for k, v := range parsedTypes {
×
3958
                records[uint64(k)] = v
×
3959
        }
×
3960

3961
        return records, nil
×
3962
}
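
// An illustrative sketch of how marshalExtraOpaqueData relates to the
// serialisation used elsewhere in this file (error handling elided, variable
// names hypothetical):
//
//        // Parse a node announcement's opaque TLV bytes into a
//        // type=>value map.
//        records, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
//
//        // The same map can be re-encoded into a canonical TLV stream via
//        // lnwire.CustomRecords, as is done when rebuilding edge and node
//        // info.
//        raw, err := lnwire.CustomRecords(records).Serialize()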
3963

3964
// insertChannel inserts a new channel record into the database.
3965
func insertChannel(ctx context.Context, db SQLQueries,
3966
        edge *models.ChannelEdgeInfo) error {
×
3967

×
3968
        // Make sure that at least a "shell" entry for each node is present in
×
3969
        // the nodes table.
×
3970
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3971
        if err != nil {
×
3972
                return fmt.Errorf("unable to create shell node: %w", err)
×
3973
        }
×
3974

3975
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3976
        if err != nil {
×
3977
                return fmt.Errorf("unable to create shell node: %w", err)
×
3978
        }
×
3979

3980
        var capacity sql.NullInt64
×
3981
        if edge.Capacity != 0 {
×
3982
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3983
        }
×
3984

3985
        createParams := sqlc.CreateChannelParams{
×
3986
                Version:     int16(ProtocolV1),
×
3987
                Scid:        channelIDToBytes(edge.ChannelID),
×
3988
                NodeID1:     node1DBID,
×
3989
                NodeID2:     node2DBID,
×
3990
                Outpoint:    edge.ChannelPoint.String(),
×
3991
                Capacity:    capacity,
×
3992
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3993
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3994
        }
×
3995

×
3996
        if edge.AuthProof != nil {
×
3997
                proof := edge.AuthProof
×
3998

×
3999
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4000
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4001
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4002
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4003
        }
×
4004

4005
        // Insert the new channel record.
4006
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4007
        if err != nil {
×
4008
                return err
×
4009
        }
×
4010

4011
        // Insert any channel features.
4012
        for feature := range edge.Features.Features() {
×
4013
                err = db.InsertChannelFeature(
×
4014
                        ctx, sqlc.InsertChannelFeatureParams{
×
4015
                                ChannelID:  dbChanID,
×
4016
                                FeatureBit: int32(feature),
×
4017
                        },
×
4018
                )
×
4019
                if err != nil {
×
4020
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4021
                                "feature(%v): %w", dbChanID, feature, err)
×
4022
                }
×
4023
        }
4024

4025
        // Finally, insert any extra TLV fields in the channel announcement.
4026
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4027
        if err != nil {
×
4028
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4029
                        err)
×
4030
        }
×
4031

4032
        for tlvType, value := range extra {
×
4033
                err := db.UpsertChannelExtraType(
×
4034
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4035
                                ChannelID: dbChanID,
×
4036
                                Type:      int64(tlvType),
×
4037
                                Value:     value,
×
4038
                        },
×
4039
                )
×
4040
                if err != nil {
×
4041
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4042
                                "extra signed field(%v): %w", edge.ChannelID,
×
4043
                                tlvType, err)
×
4044
                }
×
4045
        }
4046

4047
        return nil
×
4048
}
4049

4050
// maybeCreateShellNode checks whether a node entry already exists for the
4051
// given public key. If it does not, then a new shell node entry is created.
4052
// The ID of the node is returned. A shell node only has a protocol version
4053
// and public key persisted.
4054
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4055
        pubKey route.Vertex) (int64, error) {
×
4056

×
4057
        dbNode, err := db.GetNodeByPubKey(
×
4058
                ctx, sqlc.GetNodeByPubKeyParams{
×
4059
                        PubKey:  pubKey[:],
×
4060
                        Version: int16(ProtocolV1),
×
4061
                },
×
4062
        )
×
4063
        // The node exists. Return the ID.
×
4064
        if err == nil {
×
4065
                return dbNode.ID, nil
×
4066
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4067
                return 0, err
×
4068
        }
×
4069

4070
        // Otherwise, the node does not exist, so we create a shell entry for
4071
        // it.
4072
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4073
                Version: int16(ProtocolV1),
×
4074
                PubKey:  pubKey[:],
×
4075
        })
×
4076
        if err != nil {
×
4077
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4078
        }
×
4079

4080
        return id, nil
×
4081
}
4082

4083
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4084
// the database. This includes deleting any existing types and then inserting
4085
// the new types.
4086
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4087
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4088

×
4089
        // Delete all existing extra signed fields for the channel policy.
×
4090
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4091
        if err != nil {
×
4092
                return fmt.Errorf("unable to delete "+
×
4093
                        "existing policy extra signed fields for policy %d: %w",
×
4094
                        chanPolicyID, err)
×
4095
        }
×
4096

4097
        // Insert all new extra signed fields for the channel policy.
4098
        for tlvType, value := range extraFields {
×
4099
                err = db.UpsertChanPolicyExtraType(
×
4100
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4101
                                ChannelPolicyID: chanPolicyID,
×
4102
                                Type:            int64(tlvType),
×
4103
                                Value:           value,
×
4104
                        },
×
4105
                )
×
4106
                if err != nil {
×
4107
                        return fmt.Errorf("unable to insert "+
×
4108
                                "channel_policy(%d) extra signed field(%v): %w",
×
4109
                                chanPolicyID, tlvType, err)
×
4110
                }
×
4111
        }
4112

4113
        return nil
×
4114
}
4115

4116
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4117
// provided dbChan record and also fetches any other required information
4118
// to construct the edge info.
4119
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4120
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4121
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4122

×
4123
        data, err := batchLoadChannelData(
×
4124
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4125
        )
×
4126
        if err != nil {
×
4127
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4128
                        err)
×
4129
        }
×
4130

4131
        return buildEdgeInfoWithBatchData(
×
4132
                cfg.ChainHash, dbChan, node1, node2, data,
×
4133
        )
×
4134
}
4135

4136
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4137
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4138
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4139
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4140

×
4141
        if dbChan.Version != int16(ProtocolV1) {
×
4142
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4143
                        dbChan.Version)
×
4144
        }
×
4145

4146
        // Use the pre-loaded channel features and extra TLV types.
4147
        fv := lnwire.EmptyFeatureVector()
×
4148
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4149
                for _, bit := range features {
×
4150
                        fv.Set(lnwire.FeatureBit(bit))
×
4151
                }
×
4152
        }
4153

4154
        var extras map[uint64][]byte
×
4155
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4156
        if exists {
×
4157
                extras = channelExtras
×
4158
        } else {
×
4159
                extras = make(map[uint64][]byte)
×
4160
        }
×
4161

4162
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4163
        if err != nil {
×
4164
                return nil, err
×
4165
        }
×
4166

4167
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4168
        if err != nil {
×
4169
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4170
                        "fields: %w", err)
×
4171
        }
×
4172
        if recs == nil {
×
4173
                recs = make([]byte, 0)
×
4174
        }
×
4175

4176
        var btcKey1, btcKey2 route.Vertex
×
4177
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4178
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4179

×
4180
        channel := &models.ChannelEdgeInfo{
×
4181
                ChainHash:        chain,
×
4182
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4183
                NodeKey1Bytes:    node1,
×
4184
                NodeKey2Bytes:    node2,
×
4185
                BitcoinKey1Bytes: btcKey1,
×
4186
                BitcoinKey2Bytes: btcKey2,
×
4187
                ChannelPoint:     *op,
×
4188
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4189
                Features:         fv,
×
4190
                ExtraOpaqueData:  recs,
×
4191
        }
×
4192

×
4193
        // We always set all the signatures at the same time, so we can
×
4194
        // safely check if one signature is present to determine if we have the
×
4195
        // rest of the signatures for the auth proof.
×
4196
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4197
                channel.AuthProof = &models.ChannelAuthProof{
×
4198
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4199
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4200
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4201
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4202
                }
×
4203
        }
×
4204

4205
        return channel, nil
×
4206
}
4207

4208
// buildNodeVertices is a helper that converts raw node public keys
4209
// into route.Vertex instances.
4210
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4211
        route.Vertex, error) {
×
4212

×
4213
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4214
        if err != nil {
×
4215
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4216
                        "create vertex from node1 pubkey: %w", err)
×
4217
        }
×
4218

4219
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4220
        if err != nil {
×
4221
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4222
                        "create vertex from node2 pubkey: %w", err)
×
4223
        }
×
4224

4225
        return node1Vertex, node2Vertex, nil
×
4226
}
4227

4228
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
4229
// also retrieves all the extra info required to build the complete
4230
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4231
// the provided sqlc.GraphChannelPolicy records are nil.
4232
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4233
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4234
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4235
        *models.ChannelEdgePolicy, error) {
×
4236

×
4237
        if dbPol1 == nil && dbPol2 == nil {
×
4238
                return nil, nil, nil
×
4239
        }
×
4240

4241
        var policyIDs = make([]int64, 0, 2)
×
4242
        if dbPol1 != nil {
×
4243
                policyIDs = append(policyIDs, dbPol1.ID)
×
4244
        }
×
4245
        if dbPol2 != nil {
×
4246
                policyIDs = append(policyIDs, dbPol2.ID)
×
4247
        }
×
4248

4249
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4250
        if err != nil {
×
4251
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4252
                        "data: %w", err)
×
4253
        }
×
4254

4255
        pol1, err := buildChanPolicyWithBatchData(
×
4256
                dbPol1, channelID, node2, batchData,
×
4257
        )
×
4258
        if err != nil {
×
4259
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4260
        }
×
4261

4262
        pol2, err := buildChanPolicyWithBatchData(
×
4263
                dbPol2, channelID, node1, batchData,
×
4264
        )
×
4265
        if err != nil {
×
4266
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4267
        }
×
4268

4269
        return pol1, pol2, nil
×
4270
}
4271

4272
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4273
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4274
// then nil is returned for it.
4275
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4276
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4277
        *models.CachedEdgePolicy, error) {
×
4278

×
4279
        var p1, p2 *models.CachedEdgePolicy
×
4280
        if dbPol1 != nil {
×
4281
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4282
                if err != nil {
×
4283
                        return nil, nil, err
×
4284
                }
×
4285

4286
                p1 = models.NewCachedPolicy(policy1)
×
4287
        }
4288
        if dbPol2 != nil {
×
4289
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4290
                if err != nil {
×
4291
                        return nil, nil, err
×
4292
                }
×
4293

4294
                p2 = models.NewCachedPolicy(policy2)
×
4295
        }
4296

4297
        return p1, p2, nil
×
4298
}
4299

4300
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4301
// provided sqlc.GraphChannelPolicy and other required information.
4302
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4303
        extras map[uint64][]byte,
4304
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4305

×
4306
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4307
        if err != nil {
×
4308
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4309
                        "fields: %w", err)
×
4310
        }
×
4311

4312
        var inboundFee fn.Option[lnwire.Fee]
×
4313
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4314
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4315

×
4316
                inboundFee = fn.Some(lnwire.Fee{
×
4317
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4318
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4319
                })
×
4320
        }
×
4321

4322
        return &models.ChannelEdgePolicy{
×
4323
                SigBytes:  dbPolicy.Signature,
×
4324
                ChannelID: channelID,
×
4325
                LastUpdate: time.Unix(
×
4326
                        dbPolicy.LastUpdate.Int64, 0,
×
4327
                ),
×
4328
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4329
                        dbPolicy.MessageFlags,
×
4330
                ),
×
4331
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4332
                        dbPolicy.ChannelFlags,
×
4333
                ),
×
4334
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4335
                MinHTLC: lnwire.MilliSatoshi(
×
4336
                        dbPolicy.MinHtlcMsat,
×
4337
                ),
×
4338
                MaxHTLC: lnwire.MilliSatoshi(
×
4339
                        dbPolicy.MaxHtlcMsat.Int64,
×
4340
                ),
×
4341
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4342
                        dbPolicy.BaseFeeMsat,
×
4343
                ),
×
4344
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4345
                ToNode:                    toNode,
×
4346
                InboundFee:                inboundFee,
×
4347
                ExtraOpaqueData:           recs,
×
4348
        }, nil
×
4349
}
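
// An illustrative note on the inbound fee handling above (column values are
// hypothetical): the option is populated when at least one of the two inbound
// columns is valid, so a row with InboundBaseFeeMsat = {Valid: true, Int64:
// 1000} and the rate column unset yields
//
//        fn.Some(lnwire.Fee{BaseFee: 1000, FeeRate: 0})
//
// while a row with neither column set leaves InboundFee as the none option.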
4350

4351
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4352
// row which is expected to be a sqlc type that contains channel policy
4353
// information. It returns two policies, which may be nil if the policy
4354
// information is not present in the row.
4355
//
4356
//nolint:ll,dupl,funlen
4357
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4358
        *sqlc.GraphChannelPolicy, error) {
×
4359

×
4360
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4361
        switch r := row.(type) {
×
4362
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4363
                if r.Policy1Timelock.Valid {
×
4364
                        policy1 = &sqlc.GraphChannelPolicy{
×
4365
                                Timelock:                r.Policy1Timelock.Int32,
×
4366
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4367
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4368
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4369
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4370
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4371
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4372
                                Disabled:                r.Policy1Disabled,
×
4373
                                MessageFlags:            r.Policy1MessageFlags,
×
4374
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4375
                        }
×
4376
                }
×
4377
                if r.Policy2Timelock.Valid {
×
4378
                        policy2 = &sqlc.GraphChannelPolicy{
×
4379
                                Timelock:                r.Policy2Timelock.Int32,
×
4380
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4381
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4382
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4383
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4384
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4385
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4386
                                Disabled:                r.Policy2Disabled,
×
4387
                                MessageFlags:            r.Policy2MessageFlags,
×
4388
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4389
                        }
×
4390
                }
×
4391

4392
                return policy1, policy2, nil
×
4393

4394
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4395
                if r.Policy1ID.Valid {
×
4396
                        policy1 = &sqlc.GraphChannelPolicy{
×
4397
                                ID:                      r.Policy1ID.Int64,
×
4398
                                Version:                 r.Policy1Version.Int16,
×
4399
                                ChannelID:               r.GraphChannel.ID,
×
4400
                                NodeID:                  r.Policy1NodeID.Int64,
×
4401
                                Timelock:                r.Policy1Timelock.Int32,
×
4402
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4403
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4404
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4405
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4406
                                LastUpdate:              r.Policy1LastUpdate,
×
4407
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4408
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4409
                                Disabled:                r.Policy1Disabled,
×
4410
                                MessageFlags:            r.Policy1MessageFlags,
×
4411
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4412
                                Signature:               r.Policy1Signature,
×
4413
                        }
×
4414
                }
×
4415
                if r.Policy2ID.Valid {
×
4416
                        policy2 = &sqlc.GraphChannelPolicy{
×
4417
                                ID:                      r.Policy2ID.Int64,
×
4418
                                Version:                 r.Policy2Version.Int16,
×
4419
                                ChannelID:               r.GraphChannel.ID,
×
4420
                                NodeID:                  r.Policy2NodeID.Int64,
×
4421
                                Timelock:                r.Policy2Timelock.Int32,
×
4422
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4423
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4424
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4425
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4426
                                LastUpdate:              r.Policy2LastUpdate,
×
4427
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4428
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4429
                                Disabled:                r.Policy2Disabled,
×
4430
                                MessageFlags:            r.Policy2MessageFlags,
×
4431
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4432
                                Signature:               r.Policy2Signature,
×
4433
                        }
×
4434
                }
×
4435

4436
                return policy1, policy2, nil
×
4437

4438
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4439
                if r.Policy1ID.Valid {
×
4440
                        policy1 = &sqlc.GraphChannelPolicy{
×
4441
                                ID:                      r.Policy1ID.Int64,
×
4442
                                Version:                 r.Policy1Version.Int16,
×
4443
                                ChannelID:               r.GraphChannel.ID,
×
4444
                                NodeID:                  r.Policy1NodeID.Int64,
×
4445
                                Timelock:                r.Policy1Timelock.Int32,
×
4446
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4447
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4448
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4449
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4450
                                LastUpdate:              r.Policy1LastUpdate,
×
4451
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4452
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4453
                                Disabled:                r.Policy1Disabled,
×
4454
                                MessageFlags:            r.Policy1MessageFlags,
×
4455
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4456
                                Signature:               r.Policy1Signature,
×
4457
                        }
×
4458
                }
×
4459
                if r.Policy2ID.Valid {
×
4460
                        policy2 = &sqlc.GraphChannelPolicy{
×
4461
                                ID:                      r.Policy2ID.Int64,
×
4462
                                Version:                 r.Policy2Version.Int16,
×
4463
                                ChannelID:               r.GraphChannel.ID,
×
4464
                                NodeID:                  r.Policy2NodeID.Int64,
×
4465
                                Timelock:                r.Policy2Timelock.Int32,
×
4466
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4467
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4468
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4469
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4470
                                LastUpdate:              r.Policy2LastUpdate,
×
4471
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4472
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4473
                                Disabled:                r.Policy2Disabled,
×
4474
                                MessageFlags:            r.Policy2MessageFlags,
×
4475
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4476
                                Signature:               r.Policy2Signature,
×
4477
                        }
×
4478
                }
×
4479

4480
                return policy1, policy2, nil
×
4481

4482
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4483
                if r.Policy1ID.Valid {
×
4484
                        policy1 = &sqlc.GraphChannelPolicy{
×
4485
                                ID:                      r.Policy1ID.Int64,
×
4486
                                Version:                 r.Policy1Version.Int16,
×
4487
                                ChannelID:               r.GraphChannel.ID,
×
4488
                                NodeID:                  r.Policy1NodeID.Int64,
×
4489
                                Timelock:                r.Policy1Timelock.Int32,
×
4490
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4491
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4492
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4493
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4494
                                LastUpdate:              r.Policy1LastUpdate,
×
4495
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4496
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4497
                                Disabled:                r.Policy1Disabled,
×
4498
                                MessageFlags:            r.Policy1MessageFlags,
×
4499
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4500
                                Signature:               r.Policy1Signature,
×
4501
                        }
×
4502
                }
×
4503
                if r.Policy2ID.Valid {
×
4504
                        policy2 = &sqlc.GraphChannelPolicy{
×
4505
                                ID:                      r.Policy2ID.Int64,
×
4506
                                Version:                 r.Policy2Version.Int16,
×
4507
                                ChannelID:               r.GraphChannel.ID,
×
4508
                                NodeID:                  r.Policy2NodeID.Int64,
×
4509
                                Timelock:                r.Policy2Timelock.Int32,
×
4510
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4511
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4512
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4513
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4514
                                LastUpdate:              r.Policy2LastUpdate,
×
4515
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4516
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4517
                                Disabled:                r.Policy2Disabled,
×
4518
                                MessageFlags:            r.Policy2MessageFlags,
×
4519
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4520
                                Signature:               r.Policy2Signature,
×
4521
                        }
×
4522
                }
×
4523

4524
                return policy1, policy2, nil
×
4525

4526
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4527
                if r.Policy1ID.Valid {
×
4528
                        policy1 = &sqlc.GraphChannelPolicy{
×
4529
                                ID:                      r.Policy1ID.Int64,
×
4530
                                Version:                 r.Policy1Version.Int16,
×
4531
                                ChannelID:               r.GraphChannel.ID,
×
4532
                                NodeID:                  r.Policy1NodeID.Int64,
×
4533
                                Timelock:                r.Policy1Timelock.Int32,
×
4534
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4535
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4536
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4537
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4538
                                LastUpdate:              r.Policy1LastUpdate,
×
4539
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4540
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4541
                                Disabled:                r.Policy1Disabled,
×
4542
                                MessageFlags:            r.Policy1MessageFlags,
×
4543
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4544
                                Signature:               r.Policy1Signature,
×
4545
                        }
×
4546
                }
×
4547
                if r.Policy2ID.Valid {
×
4548
                        policy2 = &sqlc.GraphChannelPolicy{
×
4549
                                ID:                      r.Policy2ID.Int64,
×
4550
                                Version:                 r.Policy2Version.Int16,
×
4551
                                ChannelID:               r.GraphChannel.ID,
×
4552
                                NodeID:                  r.Policy2NodeID.Int64,
×
4553
                                Timelock:                r.Policy2Timelock.Int32,
×
4554
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4555
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4556
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4557
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4558
                                LastUpdate:              r.Policy2LastUpdate,
×
4559
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4560
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4561
                                Disabled:                r.Policy2Disabled,
×
4562
                                MessageFlags:            r.Policy2MessageFlags,
×
4563
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4564
                                Signature:               r.Policy2Signature,
×
4565
                        }
×
4566
                }
×
4567

4568
                return policy1, policy2, nil
×
4569

4570
        case sqlc.ListChannelsForNodeIDsRow:
×
4571
                if r.Policy1ID.Valid {
×
4572
                        policy1 = &sqlc.GraphChannelPolicy{
×
4573
                                ID:                      r.Policy1ID.Int64,
×
4574
                                Version:                 r.Policy1Version.Int16,
×
4575
                                ChannelID:               r.GraphChannel.ID,
×
4576
                                NodeID:                  r.Policy1NodeID.Int64,
×
4577
                                Timelock:                r.Policy1Timelock.Int32,
×
4578
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4579
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4580
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4581
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4582
                                LastUpdate:              r.Policy1LastUpdate,
×
4583
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4584
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4585
                                Disabled:                r.Policy1Disabled,
×
4586
                                MessageFlags:            r.Policy1MessageFlags,
×
4587
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4588
                                Signature:               r.Policy1Signature,
×
4589
                        }
×
4590
                }
×
4591
                if r.Policy2ID.Valid {
×
4592
                        policy2 = &sqlc.GraphChannelPolicy{
×
4593
                                ID:                      r.Policy2ID.Int64,
×
4594
                                Version:                 r.Policy2Version.Int16,
×
4595
                                ChannelID:               r.GraphChannel.ID,
×
4596
                                NodeID:                  r.Policy2NodeID.Int64,
×
4597
                                Timelock:                r.Policy2Timelock.Int32,
×
4598
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4599
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4600
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4601
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4602
                                LastUpdate:              r.Policy2LastUpdate,
×
4603
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4604
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4605
                                Disabled:                r.Policy2Disabled,
×
4606
                                MessageFlags:            r.Policy2MessageFlags,
×
4607
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4608
                                Signature:               r.Policy2Signature,
×
4609
                        }
×
4610
                }
×
4611

4612
                return policy1, policy2, nil
×
4613

4614
        case sqlc.ListChannelsByNodeIDRow:
×
4615
                if r.Policy1ID.Valid {
×
4616
                        policy1 = &sqlc.GraphChannelPolicy{
×
4617
                                ID:                      r.Policy1ID.Int64,
×
4618
                                Version:                 r.Policy1Version.Int16,
×
4619
                                ChannelID:               r.GraphChannel.ID,
×
4620
                                NodeID:                  r.Policy1NodeID.Int64,
×
4621
                                Timelock:                r.Policy1Timelock.Int32,
×
4622
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4623
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4624
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4625
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4626
                                LastUpdate:              r.Policy1LastUpdate,
×
4627
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4628
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4629
                                Disabled:                r.Policy1Disabled,
×
4630
                                MessageFlags:            r.Policy1MessageFlags,
×
4631
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4632
                                Signature:               r.Policy1Signature,
×
4633
                        }
×
4634
                }
×
4635
                if r.Policy2ID.Valid {
×
4636
                        policy2 = &sqlc.GraphChannelPolicy{
×
4637
                                ID:                      r.Policy2ID.Int64,
×
4638
                                Version:                 r.Policy2Version.Int16,
×
4639
                                ChannelID:               r.GraphChannel.ID,
×
4640
                                NodeID:                  r.Policy2NodeID.Int64,
×
4641
                                Timelock:                r.Policy2Timelock.Int32,
×
4642
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4643
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	default:
		return nil, nil, fmt.Errorf("unexpected row type in "+
			"extractChannelPolicies: %T", r)
	}
}

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
	var chanIDB [8]byte
	byteOrder.PutUint64(chanIDB[:], channelID)

	return chanIDB[:]
}

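// The following example is an editorial sketch and is not part of the
// original file. It assumes byteOrder is the package-level binary.ByteOrder
// used by channelIDToBytes, and simply illustrates that the 8-byte key
// produced above can be decoded back into the original SCID.
func exampleChannelIDRoundTrip() error {
	scid := uint64(871203)

	key := channelIDToBytes(scid)
	if len(key) != 8 {
		return fmt.Errorf("expected 8-byte key, got %d", len(key))
	}

	// Decoding with the same byte order recovers the original SCID.
	if decoded := byteOrder.Uint64(key); decoded != scid {
		return fmt.Errorf("round trip mismatch: %d != %d", decoded,
			scid)
	}

	return nil
}
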
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
	if len(addresses) == 0 {
		return nil, nil
	}

	result := make([]net.Addr, 0, len(addresses))
	for _, addr := range addresses {
		netAddr, err := parseAddress(addr.addrType, addr.address)
		if err != nil {
			return nil, fmt.Errorf("unable to parse address %s "+
				"of type %d: %w", addr.address, addr.addrType,
				err)
		}
		if netAddr != nil {
			result = append(result, netAddr)
		}
	}

	// If we have no valid addresses, return nil instead of empty slice.
	if len(result) == 0 {
		return nil, nil
	}

	return result, nil
}

// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// DNS and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
	switch addrType {
	case addressTypeIPv4:
		tcp, err := net.ResolveTCPAddr("tcp4", address)
		if err != nil {
			return nil, err
		}

		tcp.IP = tcp.IP.To4()

		return tcp, nil

	case addressTypeIPv6:
		tcp, err := net.ResolveTCPAddr("tcp6", address)
		if err != nil {
			return nil, err
		}

		return tcp, nil

	case addressTypeTorV3, addressTypeTorV2:
		service, portStr, err := net.SplitHostPort(address)
		if err != nil {
			return nil, fmt.Errorf("unable to split tor "+
				"address: %v", address)
		}

		port, err := strconv.Atoi(portStr)
		if err != nil {
			return nil, err
		}

		return &tor.OnionAddr{
			OnionService: service,
			Port:         port,
		}, nil

	case addressTypeDNS:
		hostname, portStr, err := net.SplitHostPort(address)
		if err != nil {
			return nil, fmt.Errorf("unable to split DNS "+
				"address: %v", address)
		}

		port, err := strconv.Atoi(portStr)
		if err != nil {
			return nil, err
		}

		return &lnwire.DNSAddress{
			Hostname: hostname,
			Port:     uint16(port),
		}, nil

	case addressTypeOpaque:
		opaque, err := hex.DecodeString(address)
		if err != nil {
			return nil, fmt.Errorf("unable to decode opaque "+
				"address: %v", address)
		}

		return &lnwire.OpaqueAddrs{
			Payload: opaque,
		}, nil

	default:
		return nil, fmt.Errorf("unknown address type: %v", addrType)
	}
}

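// Editorial sketch (not part of the original file): a minimal illustration of
// how buildNodeAddresses turns stored rows into net.Addr values. The literal
// addresses below are placeholders chosen for the example only.
func exampleBuildNodeAddresses() ([]net.Addr, error) {
	stored := []nodeAddress{
		{
			addrType: addressTypeIPv4,
			position: 0,
			address:  "10.0.0.1:9735",
		},
		{
			addrType: addressTypeTorV3,
			position: 1,
			// Placeholder onion host: parseAddress only splits
			// host and port here, it does not validate the
			// service name.
			address: "exampleonionservice.onion:9735",
		},
	}

	// Each entry is parsed according to its type: IPv4/IPv6 entries
	// become *net.TCPAddr, Tor entries become *tor.OnionAddr, and so on.
	return buildNodeAddresses(stored)
}
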
// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
	// features is a map from a DB node ID to the feature bits for that
	// node.
	features map[int64][]int

	// addresses is a map from a DB node ID to the node's addresses.
	addresses map[int64][]nodeAddress

	// extraFields is a map from a DB node ID to the extra signed fields
	// for that node.
	extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
	addrType dbAddressType
	position int32
	address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance containing
// the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

	// Batch load the node features.
	features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node "+
			"features: %w", err)
	}

	// Batch load the node addresses.
	addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node "+
			"addresses: %w", err)
	}

	// Batch load the node extra signed fields.
	extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node extra "+
			"signed fields: %w", err)
	}

	return &batchNodeData{
		features:    features,
		addresses:   addrs,
		extraFields: extraTypes,
	}, nil
}

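// Editorial sketch (not part of the original file): shows how the maps in
// batchNodeData are keyed by DB node ID once batchLoadNodeData returns. The
// cfg, db and nodeIDs values are assumed to come from the caller.
func exampleUseBatchNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodeIDs []int64) (int, error) {

	data, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
	if err != nil {
		return 0, err
	}

	// Each map is keyed by the DB node ID. A node with no rows in a given
	// table simply has no entry, so lookups fall back to zero values.
	var totalAddrs int
	for _, id := range nodeIDs {
		totalAddrs += len(data.addresses[id])

		_ = data.features[id]
		_ = data.extraFields[id]
	}

	return totalAddrs, nil
}
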
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64][]int, error) {

	features := make(map[int64][]int)

	return features, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
			error) {

			return db.GetNodeFeaturesBatch(ctx, ids)
		},
		func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
			features[feature.NodeID] = append(
				features[feature.NodeID],
				int(feature.FeatureBit),
			)

			return nil
		},
	)
}

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64][]nodeAddress, error) {

	addrs := make(map[int64][]nodeAddress)

	return addrs, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
			error) {

			return db.GetNodeAddressesBatch(ctx, ids)
		},
		func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
			addrs[addr.NodeID] = append(
				addrs[addr.NodeID], nodeAddress{
					addrType: dbAddressType(addr.Type),
					position: addr.Position,
					address:  addr.Address,
				},
			)

			return nil
		},
	)
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

	extraFields := make(map[int64]map[uint64][]byte)

	callback := func(ctx context.Context,
		field sqlc.GraphNodeExtraType) error {

		if extraFields[field.NodeID] == nil {
			extraFields[field.NodeID] = make(map[uint64][]byte)
		}
		extraFields[field.NodeID][uint64(field.Type)] = field.Value

		return nil
	}

	return extraFields, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) (
			[]sqlc.GraphNodeExtraType, error) {

			return db.GetNodeExtraTypesBatch(ctx, ids)
		},
		callback,
	)
}

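// Editorial sketch (not part of the original file): the three helpers above
// all follow the same pattern, so this hypothetical generic function shows
// the underlying idea of a chunked batch query without relying on the real
// sqldb.ExecuteBatchQuery implementation: split the IDs into bounded chunks,
// run one query per chunk, and feed every row to a callback that fills an
// accumulator map owned by the caller.
func exampleChunkedBatchQuery[K any, R any](ctx context.Context, ids []K,
	chunkSize int, query func(context.Context, []K) ([]R, error),
	onRow func(R) error) error {

	if chunkSize <= 0 {
		return fmt.Errorf("chunk size must be positive")
	}

	for start := 0; start < len(ids); start += chunkSize {
		end := min(start+chunkSize, len(ids))

		rows, err := query(ctx, ids[start:end])
		if err != nil {
			return err
		}

		for _, row := range rows {
			if err := onRow(row); err != nil {
				return err
			}
		}
	}

	return nil
}
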
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the
// provided batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
	channelID uint64, node1, node2 route.Vertex,
	batchData *batchChannelData) (*models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	pol1, err := buildChanPolicyWithBatchData(
		dbPol1, channelID, node2, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
	}

	pol2, err := buildChanPolicyWithBatchData(
		dbPol2, channelID, node1, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
	}

	return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
	channelID uint64, toNode route.Vertex,
	batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

	if dbPol == nil {
		return nil, nil
	}

	var dbPol1Extras map[uint64][]byte
	if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
		dbPol1Extras = extras
	} else {
		dbPol1Extras = make(map[uint64][]byte)
	}

	return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, channelIDs []int64,
	policyIDs []int64) (*batchChannelData, error) {

	batchData := &batchChannelData{
		chanfeatures:   make(map[int64][]int),
		chanExtraTypes: make(map[int64]map[uint64][]byte),
		policyExtras:   make(map[int64]map[uint64][]byte),
	}

	// Batch load channel features and extras
	var err error
	if len(channelIDs) > 0 {
		batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
			ctx, cfg, db, channelIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"channel features: %w", err)
		}

		batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
			ctx, cfg, db, channelIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"channel extras: %w", err)
		}
	}

	if len(policyIDs) > 0 {
		policyExtras, err := batchLoadChannelPolicyExtrasHelper(
			ctx, cfg, db, policyIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"policy extras: %w", err)
		}
		batchData.policyExtras = policyExtras
	}

	return batchData, nil
}

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	channelIDs []int64) (map[int64][]int, error) {

	features := make(map[int64][]int)

	return features, sqldb.ExecuteBatchQuery(
		ctx, cfg, channelIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context,
			ids []int64) ([]sqlc.GraphChannelFeature, error) {

			return db.GetChannelFeaturesBatch(ctx, ids)
		},
		func(ctx context.Context,
			feature sqlc.GraphChannelFeature) error {

			features[feature.ChannelID] = append(
				features[feature.ChannelID],
				int(feature.FeatureBit),
			)

			return nil
		},
	)
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	channelIDs []int64) (map[int64]map[uint64][]byte, error) {

	extras := make(map[int64]map[uint64][]byte)

	cb := func(ctx context.Context,
		extra sqlc.GraphChannelExtraType) error {

		if extras[extra.ChannelID] == nil {
			extras[extra.ChannelID] = make(map[uint64][]byte)
		}
		extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

		return nil
	}

	return extras, sqldb.ExecuteBatchQuery(
		ctx, cfg, channelIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context,
			ids []int64) ([]sqlc.GraphChannelExtraType, error) {

			return db.GetChannelExtrasBatch(ctx, ids)
		}, cb,
	)
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID
// to a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	policyIDs []int64) (map[int64]map[uint64][]byte, error) {

	extras := make(map[int64]map[uint64][]byte)

	return extras, sqldb.ExecuteBatchQuery(
		ctx, cfg, policyIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) (
			[]sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

			return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
		},
		func(ctx context.Context,
			row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

			if extras[row.PolicyID] == nil {
				extras[row.PolicyID] = make(map[uint64][]byte)
			}
			extras[row.PolicyID][uint64(row.Type)] = row.Value

			return nil
		},
	)
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, protocol ProtocolVersion,
	processNode func(context.Context, int64,
		*models.Node) error) error {

	pageQueryFunc := func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.GraphNode, error) {

		return db.ListNodesPaginated(
			ctx, sqlc.ListNodesPaginatedParams{
				Version: int16(protocol),
				ID:      lastID,
				Limit:   limit,
			},
		)
	}

	extractPageCursor := func(node sqlc.GraphNode) int64 {
		return node.ID
	}

	collectFunc := func(node sqlc.GraphNode) (int64, error) {
		return node.ID, nil
	}

	batchQueryFunc := func(ctx context.Context,
		nodeIDs []int64) (*batchNodeData, error) {

		return batchLoadNodeData(ctx, cfg, db, nodeIDs)
	}

	processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
		batchData *batchNodeData) error {

		node, err := buildNodeWithBatchData(dbNode, batchData)
		if err != nil {
			return fmt.Errorf("unable to build "+
				"node(id=%d): %w", dbNode.ID, err)
		}

		return processNode(ctx, dbNode.ID, node)
	}

	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
		ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
		collectFunc, batchQueryFunc, processItem,
	)
}

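// Editorial sketch (not part of the original file): a simplified,
// hypothetical stand-in for the collect-then-batch pagination that
// forEachNodePaginated delegates to the sqldb helper above. It shows the
// keyset pagination idea: fetch one page starting after the last seen
// cursor, advance the cursor, process the page (which is where the per-page
// batch loading would happen), and stop on a short or empty page.
func examplePaginateCollectAndBatch(ctx context.Context, pageSize int32,
	fetchPage func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.GraphNode, error),
	processPage func(ctx context.Context,
		page []sqlc.GraphNode) error) error {

	lastID := int64(-1)
	for {
		page, err := fetchPage(ctx, lastID, pageSize)
		if err != nil {
			return err
		}
		if len(page) == 0 {
			return nil
		}

		// The cursor advances to the last ID on this page, so the
		// next query resumes strictly after it.
		lastID = page[len(page)-1].ID

		if err := processPage(ctx, page); err != nil {
			return err
		}

		// A short page means we have reached the end.
		if int32(len(page)) < pageSize {
			return nil
		}
	}
}
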
// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
	cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
		*models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error) error {

	type channelBatchIDs struct {
		channelID int64
		policyIDs []int64
	}

	pageQueryFunc := func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
		error) {

		return db.ListChannelsWithPoliciesPaginated(
			ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
				Version: int16(ProtocolV1),
				ID:      lastID,
				Limit:   limit,
			},
		)
	}

	extractPageCursor := func(
		row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

		return row.GraphChannel.ID
	}

	collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
		channelBatchIDs, error) {

		ids := channelBatchIDs{
			channelID: row.GraphChannel.ID,
		}

		// Extract policy IDs from the row.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return ids, err
		}

		if dbPol1 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
		}

		return ids, nil
	}

	batchDataFunc := func(ctx context.Context,
		allIDs []channelBatchIDs) (*batchChannelData, error) {

		// Separate channel IDs from policy IDs.
		var (
			channelIDs = make([]int64, len(allIDs))
			policyIDs  = make([]int64, 0, len(allIDs)*2)
		)

		for i, ids := range allIDs {
			channelIDs[i] = ids.channelID
			policyIDs = append(policyIDs, ids.policyIDs...)
		}

		return batchLoadChannelData(
			ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
		)
	}

	processItem := func(ctx context.Context,
		row sqlc.ListChannelsWithPoliciesPaginatedRow,
		batchData *batchChannelData) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.GraphChannel, node1, node2,
			batchData,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
		)
		if err != nil {
			return err
		}

		return processChannel(edge, p1, p2)
	}

	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
		ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
		collectFunc, batchDataFunc, processItem,
	)
}

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
	nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
	channelBatchData *batchChannelData, features *lnwire.FeatureVector,
	toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

	node1, node2, err := buildNodeVertices(
		channelRow.Node1Pubkey, channelRow.Node2Pubkey,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build node vertices: %w", err)
	}

	edge, err := buildEdgeInfoWithBatchData(
		chain, channelRow.GraphChannel, node1, node2, channelBatchData,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build channel info: %w", err)
	}

	dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
	if err != nil {
		return nil, fmt.Errorf("unable to extract channel policies: %w",
			err)
	}

	p1, p2, err := buildChanPoliciesWithBatchData(
		dbPol1, dbPol2, edge.ChannelID, node1, node2,
		channelBatchData,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build channel policies: %w",
			err)
	}

	// Determine outgoing and incoming policy for this specific node.
	p1ToNode := channelRow.GraphChannel.NodeID2
	p2ToNode := channelRow.GraphChannel.NodeID1
	outPolicy, inPolicy := p1, p2
	if (p1 != nil && p1ToNode == nodeID) ||
		(p2 != nil && p2ToNode != nodeID) {

		outPolicy, inPolicy = p2, p1
	}

	// Build cached policy.
	var cachedInPolicy *models.CachedEdgePolicy
	if inPolicy != nil {
		cachedInPolicy = models.NewCachedPolicy(inPolicy)
		cachedInPolicy.ToNodePubKey = toNodeCallback
		cachedInPolicy.ToNodeFeatures = features
	}

	// Extract inbound fee.
	var inboundFee lnwire.Fee
	if outPolicy != nil {
		outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
			inboundFee = fee
		})
	}

	// Build directed channel.
	directedChannel := &DirectedChannel{
		ChannelID:    edge.ChannelID,
		IsNode1:      nodePub == edge.NodeKey1Bytes,
		OtherNode:    edge.NodeKey2Bytes,
		Capacity:     edge.Capacity,
		OutPolicySet: outPolicy != nil,
		InPolicy:     cachedInPolicy,
		InboundFee:   inboundFee,
	}

	if nodePub == edge.NodeKey2Bytes {
		directedChannel.OtherNode = edge.NodeKey1Bytes
	}

	return directedChannel, nil
}

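// Editorial sketch (not part of the original file): the direction convention
// assumed above is that policy1 is announced by node1 and therefore points
// toward node2, while policy2 points toward node1. Under that convention,
// and ignoring nil handling, picking the outgoing/incoming pair for one side
// of the channel reduces to checking which DB node ID the caller is asking
// about. The first return value is the outgoing policy, the second the
// incoming one.
func exampleSelectDirection(nodeID, node1ID int64,
	p1, p2 *models.ChannelEdgePolicy) (*models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy) {

	if nodeID == node1ID {
		// We are node1: our outgoing policy is policy1, and policy2
		// points back toward us.
		return p1, p2
	}

	// Otherwise we are node2 and the roles are reversed.
	return p2, p1
}
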
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
	cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

	var (
		channelIDs = make([]int64, len(rows))
		policyIDs  = make([]int64, 0, len(rows)*2)
		nodeIDs    = make([]int64, 0, len(rows)*2)

		// nodeIDSet is used to ensure we only collect unique node IDs.
		nodeIDSet = make(map[int64]bool)

		// edges will hold the final channel edges built from the rows.
		edges = make([]ChannelEdge, 0, len(rows))
	)

	// Collect all IDs needed for batch loading.
	for i, row := range rows {
		channelIDs[i] = row.Channel().ID

		// Collect policy IDs
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return nil, fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}
		if dbPol1 != nil {
			policyIDs = append(policyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			policyIDs = append(policyIDs, dbPol2.ID)
		}

		var (
			node1ID = row.Node1().ID
			node2ID = row.Node2().ID
		)

		// Collect unique node IDs.
		if !nodeIDSet[node1ID] {
			nodeIDs = append(nodeIDs, node1ID)
			nodeIDSet[node1ID] = true
		}

		if !nodeIDSet[node2ID] {
			nodeIDs = append(nodeIDs, node2ID)
			nodeIDSet[node2ID] = true
		}
	}

	// Batch the data for all the channels and policies.
	channelBatchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load channel and "+
			"policy data: %w", err)
	}

	// Batch the data for all the nodes.
	nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node data: %w",
			err)
	}

	// Build all channel edges using batch data.
	for _, row := range rows {
		// Build nodes using batch data.
		node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
		if err != nil {
			return nil, fmt.Errorf("unable to build node1: %w", err)
		}

		node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
		if err != nil {
			return nil, fmt.Errorf("unable to build node2: %w", err)
		}

		// Build channel info using batch data.
		channel, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
			node2.PubKeyBytes, channelBatchData,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to build channel "+
				"info: %w", err)
		}

		// Extract and build policies using batch data.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return nil, fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, channel.ChannelID,
			node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		edges = append(edges, ChannelEdge{
			Info:    channel,
			Policy1: p1,
			Policy2: p2,
			Node1:   node1,
			Node2:   node2,
		})
	}

	return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
	cfg *SQLStoreConfig, db SQLQueries, rows []T) (
	[]*models.ChannelEdgeInfo, []int64, error) {

	if len(rows) == 0 {
		return nil, nil, nil
	}

	// Collect all the channel IDs needed for batch loading.
	channelIDs := make([]int64, len(rows))
	for i, row := range rows {
		channelIDs[i] = row.Channel().ID
	}

	// Batch load the channel data.
	channelBatchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, channelIDs, nil,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to batch load channel "+
			"data: %w", err)
	}

	// Build all channel edges using batch data.
	edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
	for _, row := range rows {
		node1, node2, err := buildNodeVertices(
			row.Node1Pub(), row.Node2Pub(),
		)
		if err != nil {
			return nil, nil, err
		}

		// Build channel info using batch data
		info, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.Channel(), node1, node2,
			channelBatchData,
		)
		if err != nil {
			return nil, nil, err
		}

		edges = append(edges, info)
	}

	return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
	row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
	strictZombiePruning bool, scid uint64) error {

	nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

	if strictZombiePruning {
		var e1UpdateTime, e2UpdateTime *time.Time
		if row.Policy1LastUpdate.Valid {
			e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
			e1UpdateTime = &e1Time
		}
		if row.Policy2LastUpdate.Valid {
			e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
			e2UpdateTime = &e2Time
		}

		nodeKey1, nodeKey2 = makeZombiePubkeys(
			info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
			e2UpdateTime,
		)
	}

	return db.UpsertZombieChannel(
		ctx, sqlc.UpsertZombieChannelParams{
			Version:  int16(ProtocolV1),
			Scid:     channelIDToBytes(scid),
			NodeKey1: nodeKey1[:],
			NodeKey2: nodeKey2[:],
		},
	)
}
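
// Editorial sketch (not part of the original file): illustrates how a caller
// that has already looked up a channel row and built its edge info might mark
// it as a zombie. The row, info and strict values are assumed to come from
// the surrounding pruning logic.
func exampleMarkZombie(ctx context.Context, db SQLQueries,
	row sqlc.GetChannelsBySCIDWithPoliciesRow,
	info *models.ChannelEdgeInfo, strict bool) error {

	// In non-strict mode both original node keys are written to the
	// zombie entry, while in strict mode the keys are first narrowed by
	// makeZombiePubkeys based on the policies' last update timestamps.
	return handleZombieMarking(ctx, db, row, info, strict, info.ChannelID)
}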