spesmilo / electrum / 5735552722403328 (CirrusCI)

16 May 2025 10:28AM UTC coverage: 59.722% (+0.002%) from 59.72%

Pull Request #9833: dns: use async dnspython interface
f321x: make lightning dns seed fetching async

22 of 50 new or added lines in 7 files covered. (44.0%)
1107 existing lines in 11 files now uncovered.
21549 of 36082 relevant lines covered (59.72%)
2.39 hits per line

Source File: /electrum/lnworker.py (49.83% covered)
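For context on the change being measured, here is a minimal sketch of an async SRV lookup using dnspython's asyncresolver (the same module imported by the file below). This is an illustration only; the helper name and the seed name are placeholders and the code is not taken from the PR diff.

import asyncio
import dns.asyncresolver

async def lookup_ln_seed(srv_name: str) -> list:
    # Resolve SRV records without blocking the event loop; each answer
    # carries a target hostname and a port.
    answers = await dns.asyncresolver.resolve(srv_name, 'SRV')
    return [(str(rr.target), rr.port) for rr in answers]

# Usage (placeholder seed name):
# asyncio.run(lookup_ln_seed('r0.example-seed.invalid'))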
1
# Copyright (C) 2018 The Electrum developers
2
# Distributed under the MIT software license, see the accompanying
3
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4

5
import asyncio
4✔
6
import os
4✔
7
from decimal import Decimal
4✔
8
import random
4✔
9
import time
4✔
10
from enum import IntEnum
4✔
11
from typing import (
4✔
12
    Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator,
13
    Callable, Awaitable
14
)
15
import threading
4✔
16
import socket
4✔
17
from functools import partial
4✔
18
from collections import defaultdict
4✔
19
import concurrent
4✔
20
from concurrent import futures
4✔
21
import urllib.parse
4✔
22
import itertools
4✔
23

24
import aiohttp
4✔
25
import dns.asyncresolver
4✔
26
import dns.exception
4✔
27
from aiorpcx import run_in_thread, NetAddress, ignore_after
4✔
28

29
from .logging import Logger
4✔
30
from .i18n import _
4✔
31
from .json_db import stored_in
4✔
32
from .channel_db import UpdateStatus, ChannelDBNotLoaded, get_mychannel_info, get_mychannel_policy
4✔
33

34
from . import constants, util
4✔
35
from .util import (
4✔
36
    profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
37
    event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
38
    make_aiohttp_session, timestamp_to_datetime, random_shuffled_copy, is_private_netaddress,
39
    UnrelatedTransactionException, LightningHistoryItem
40
)
41
from .fee_policy import FeePolicy, FixedFeePolicy
4✔
42
from .fee_policy import (FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
4✔
43
                         FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET)
44
from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice
4✔
45
from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress
4✔
46
from .bip32 import BIP32Node
4✔
47
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_TIMESTAMP_INF
4✔
48
from .transaction import (
4✔
49
    Transaction, get_script_type_from_output_script, PartialTxOutput, PartialTransaction, PartialTxInput
50
)
51
from .crypto import (
4✔
52
    sha256, chacha20_encrypt, chacha20_decrypt, pw_encode_with_version_and_mac, pw_decode_with_version_and_mac
53
)
54

55
from .onion_message import OnionMessageManager
4✔
56
from .lntransport import LNTransport, LNResponderTransport, LNTransportBase, LNPeerAddr, split_host_port, extract_nodeid, ConnStringFormatError
4✔
57
from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
4✔
58
from .lnaddr import lnencode, LnAddr, lndecode
4✔
59
from .lnchannel import Channel, AbstractChannel, ChannelState, PeerState, HTLCWithStatus, ChannelBackup
4✔
60
from .lnrater import LNRater
4✔
61
from .lnutil import (
4✔
62
    get_compressed_pubkey_from_bech32, serialize_htlc_key, deserialize_htlc_key, PaymentFailure, generate_keypair,
63
    LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
64
    ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
65
    OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
66
    NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT
67
)
68
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
4✔
69
from .lnmsg import decode_msg
4✔
70
from .lnrouter import (
4✔
71
    RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_within_budget, NoChannelPolicy, LNPathInconsistent
72
)
73
from .lnwatcher import LNWatcher
4✔
74
from .submarine_swaps import SwapManager
4✔
75
from .mpp_split import suggest_splits, SplitConfigRating
4✔
76
from .trampoline import (
4✔
77
    create_trampoline_route_and_onion, is_legacy_relay, trampolines_by_id, hardcoded_trampoline_nodes,
78
    is_hardcoded_trampoline
79
)
80

81
if TYPE_CHECKING:
4✔
82
    from .network import Network
×
83
    from .wallet import Abstract_Wallet
×
84
    from .channel_db import ChannelDB
×
85
    from .simple_config import SimpleConfig
×
86

87

88
SAVED_PR_STATUS = [PR_PAID, PR_UNPAID]  # status that are persisted
4✔
89

90
NUM_PEERS_TARGET = 4
4✔
91

92
# onchain channel backup data
93
CB_VERSION = 0
4✔
94
CB_MAGIC_BYTES = bytes([0, 0, 0, CB_VERSION])
4✔
95
NODE_ID_PREFIX_LEN = 16
4✔
96

97

98
class PaymentDirection(IntEnum):
4✔
99
    SENT = 0
4✔
100
    RECEIVED = 1
4✔
101
    SELF_PAYMENT = 2
4✔
102
    FORWARDING = 3
4✔
103

104

105
class PaymentInfo(NamedTuple):
4✔
106
    payment_hash: bytes
4✔
107
    amount_msat: Optional[int]
4✔
108
    direction: int
4✔
109
    status: int
4✔
110

111

112
# Note: these states are persisted in the wallet file.
113
# Do not modify them without performing a wallet db upgrade
114
class RecvMPPResolution(IntEnum):
4✔
115
    WAITING = 0
4✔
116
    EXPIRED = 1
4✔
117
    ACCEPTED = 2
4✔
118
    FAILED = 3
4✔
119

120

121
class ReceivedMPPStatus(NamedTuple):
4✔
122
    resolution: RecvMPPResolution
4✔
123
    expected_msat: int
4✔
124
    htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
4✔
125

126
    @stored_in('received_mpp_htlcs', tuple)
4✔
127
    def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
4✔
128
        htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list])
×
129
        return ReceivedMPPStatus(
×
130
            resolution=RecvMPPResolution(resolution),
131
            expected_msat=expected_msat,
132
            htlc_set=htlc_set)
133

134

135
SentHtlcKey = Tuple[bytes, ShortChannelID, int]  # RHASH, scid, htlc_id
4✔
136

137

138
class SentHtlcInfo(NamedTuple):
4✔
139
    route: LNPaymentRoute
4✔
140
    payment_secret_orig: bytes
4✔
141
    payment_secret_bucket: bytes
4✔
142
    amount_msat: int
4✔
143
    bucket_msat: int
4✔
144
    amount_receiver_msat: int
4✔
145
    trampoline_fee_level: Optional[int]
4✔
146
    trampoline_route: Optional[LNPaymentRoute]
4✔
147

148

149
class ErrorAddingPeer(Exception): pass
4✔
150

151

152
# set some feature flags as baseline for both LNWallet and LNGossip
153
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
154
BASE_FEATURES = (
4✔
155
    LnFeatures(0)
156
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
157
    | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
158
    | LnFeatures.VAR_ONION_OPT
159
    | LnFeatures.PAYMENT_SECRET_OPT
160
    | LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
161
)
162

163
# we do not want to receive unrequested gossip (see lnpeer.maybe_save_remote_update)
164
LNWALLET_FEATURES = (
4✔
165
    BASE_FEATURES
166
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
167
    | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
168
    | LnFeatures.VAR_ONION_REQ
169
    | LnFeatures.PAYMENT_SECRET_REQ
170
    | LnFeatures.BASIC_MPP_OPT
171
    | LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
172
    | LnFeatures.OPTION_SHUTDOWN_ANYSEGWIT_OPT
173
    | LnFeatures.OPTION_CHANNEL_TYPE_OPT
174
    | LnFeatures.OPTION_SCID_ALIAS_OPT
175
    | LnFeatures.OPTION_SUPPORT_LARGE_CHANNEL_OPT
176
)
177

178
LNGOSSIP_FEATURES = (
4✔
179
    BASE_FEATURES
180
    # LNGossip doesn't serve gossip but weirdly have to signal so
181
    # that peers satisfy our queries
182
    | LnFeatures.GOSSIP_QUERIES_REQ
183
    | LnFeatures.GOSSIP_QUERIES_OPT
184
)
185

186

187
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
4✔
188

189
    def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
4✔
190
        Logger.__init__(self)
4✔
191
        NetworkRetryManager.__init__(
4✔
192
            self,
193
            max_retry_delay_normal=3600,
194
            init_retry_delay_normal=600,
195
            max_retry_delay_urgent=300,
196
            init_retry_delay_urgent=4,
197
        )
198
        self.lock = threading.RLock()
4✔
199
        self.node_keypair = node_keypair
4✔
200
        self._peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer  # needs self.lock
4✔
201
        self.taskgroup = OldTaskGroup()
4✔
202
        self.listen_server = None  # type: Optional[asyncio.AbstractServer]
4✔
203
        self.features = features
4✔
204
        self.network = None  # type: Optional[Network]
4✔
205
        self.config = config
4✔
206
        self.stopping_soon = False  # whether we are being shut down
4✔
207
        self.register_callbacks()
4✔
208

209
    @property
4✔
210
    def channel_db(self) -> 'ChannelDB':
4✔
211
        return self.network.channel_db if self.network else None
×
212

213
    def uses_trampoline(self) -> bool:
4✔
214
        return not bool(self.channel_db)
×
215

216
    @property
4✔
217
    def peers(self) -> Mapping[bytes, Peer]:
4✔
218
        """Returns a read-only copy of peers."""
219
        with self.lock:
×
220
            return self._peers.copy()
×
221

222
    def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
4✔
223
        return {}
×
224

225
    def get_node_alias(self, node_id: bytes) -> Optional[str]:
4✔
226
        """Returns the alias of the node, or None if unknown."""
227
        node_alias = None
×
228
        if not self.uses_trampoline():
×
229
            node_info = self.channel_db.get_node_info_for_node_id(node_id)
×
230
            if node_info:
×
231
                node_alias = node_info.alias
×
232
        else:
233
            for k, v in hardcoded_trampoline_nodes().items():
×
234
                if v.pubkey.startswith(node_id):
×
235
                    node_alias = k
×
236
                    break
×
237
        return node_alias
×
238

239
    async def maybe_listen(self):
4✔
240
        # FIXME: only one LNWorker can listen at a time (single port)
241
        listen_addr = self.config.LIGHTNING_LISTEN
×
242
        if listen_addr:
×
243
            self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
×
244
            try:
×
245
                netaddr = NetAddress.from_string(listen_addr)
×
246
            except Exception as e:
×
247
                self.logger.error(f"failed to parse config key '{self.config.cv.LIGHTNING_LISTEN.key()}'. got: {e!r}")
×
248
                return
×
249
            addr = str(netaddr.host)
×
250

251
            async def cb(reader, writer):
×
252
                transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
×
253
                try:
×
254
                    node_id = await transport.handshake()
×
255
                except Exception as e:
×
256
                    self.logger.info(f'handshake failure from incoming connection: {e!r}')
×
257
                    return
×
258
                await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
259
            try:
×
260
                self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
×
261
            except OSError as e:
×
262
                self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")
×
263

264
    async def main_loop(self):
4✔
265
        self.logger.info("starting taskgroup.")
×
266
        try:
×
267
            async with self.taskgroup as group:
×
268
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
×
269
        except Exception as e:
×
270
            self.logger.exception("taskgroup died.")
×
271
        finally:
272
            self.logger.info("taskgroup stopped.")
×
273

274
    async def _maintain_connectivity(self):
4✔
275
        while True:
×
276
            await asyncio.sleep(1)
×
277
            if self.stopping_soon:
×
278
                return
×
279
            now = time.time()
×
280
            if len(self._peers) >= NUM_PEERS_TARGET:
×
281
                continue
×
282
            peers = await self._get_next_peers_to_try()
×
283
            for peer in peers:
×
284
                if self._can_retry_addr(peer, now=now):
×
285
                    try:
×
286
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
×
287
                    except ErrorAddingPeer as e:
×
288
                        self.logger.info(f"failed to add peer: {peer}. exc: {e!r}")
×
289

290
    async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
4✔
291
        if node_id in self._peers:
×
292
            return self._peers[node_id]
×
293
        port = int(port)
×
294
        peer_addr = LNPeerAddr(host, port, node_id)
×
295
        self._trying_addr_now(peer_addr)
×
296
        self.logger.info(f"adding peer {peer_addr}")
×
297
        if node_id == self.node_keypair.pubkey or self.is_our_lnwallet(node_id):
×
298
            raise ErrorAddingPeer("cannot connect to self")
×
299
        transport = LNTransport(self.node_keypair.privkey, peer_addr,
×
300
                                e_proxy=ESocksProxy.from_network_settings(self.network))
301
        peer = await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
302
        assert peer
×
303
        return peer
×
304

305
    async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransportBase) -> Optional[Peer]:
4✔
306
        with self.lock:
×
307
            existing_peer = self._peers.get(node_id)
×
308
            if existing_peer:
×
309
                # Two instances of the same wallet are attempting to connect simultaneously.
310
                # If we let the new connection replace the existing one, the two instances might
311
                # both keep trying to reconnect, resulting in neither being usable.
312
                if existing_peer.is_initialized():
×
313
                    # give priority to the existing connection
314
                    return
×
315
                else:
316
                    # Use the new connection. (e.g. old peer might be an outgoing connection
317
                    # for an outdated host/port that will never connect)
318
                    existing_peer.close_and_cleanup()
×
319
            peer = Peer(self, node_id, transport)
×
320
            assert node_id not in self._peers
×
321
            self._peers[node_id] = peer
×
322
        await self.taskgroup.spawn(peer.main_loop())
×
323
        return peer
×
324

325
    def peer_closed(self, peer: Peer) -> None:
4✔
326
        with self.lock:
×
327
            peer2 = self._peers.get(peer.pubkey)
×
328
            if peer2 is peer:
×
329
                self._peers.pop(peer.pubkey)
×
330

331
    def num_peers(self) -> int:
4✔
332
        return sum([p.is_initialized() for p in self.peers.values()])
×
333

334
    def is_our_lnwallet(self, node_id: bytes) -> bool:
4✔
335
        """Check if node_id is one of our own wallets"""
336
        wallets = self.network.daemon.get_wallets()
×
337
        for wallet in wallets.values():
×
338
            if wallet.lnworker and wallet.lnworker.node_keypair.pubkey == node_id:
×
339
                return True
×
340
        return False
×
341

342
    def start_network(self, network: 'Network'):
4✔
343
        assert network
×
344
        assert self.network is None, "already started"
×
345
        self.network = network
×
346
        self._add_peers_from_config()
×
347
        asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
×
348

349
    async def stop(self):
4✔
350
        if self.listen_server:
4✔
351
            self.listen_server.close()
×
352
        self.unregister_callbacks()
4✔
353
        await self.taskgroup.cancel_remaining()
4✔
354

355
    def _add_peers_from_config(self):
4✔
356
        peer_list = self.config.LIGHTNING_PEERS or []
×
357
        for host, port, pubkey in peer_list:
×
358
            asyncio.run_coroutine_threadsafe(
×
359
                self._add_peer(host, int(port), bfh(pubkey)),
360
                self.network.asyncio_loop)
361

362
    def is_good_peer(self, peer: LNPeerAddr) -> bool:
4✔
363
        # the purpose of this method is to filter peers that advertise the desired feature bits
364
        # it is disabled for now, because feature bits published in node announcements seem to be unreliable
365
        return True
×
366
        node_id = peer.pubkey
367
        node = self.channel_db._nodes.get(node_id)
368
        if not node:
369
            return False
370
        try:
371
            ln_compare_features(self.features, node.features)
372
        except IncompatibleLightningFeatures:
373
            return False
374
        #self.logger.info(f'is_good {peer.host}')
375
        return True
376

377
    def on_peer_successfully_established(self, peer: Peer) -> None:
4✔
378
        if isinstance(peer.transport, LNTransport):
4✔
379
            peer_addr = peer.transport.peer_addr
×
380
            # reset connection attempt count
381
            self._on_connection_successfully_established(peer_addr)
×
382
            if not self.uses_trampoline():
×
383
                # add into channel db
384
                self.channel_db.add_recent_peer(peer_addr)
×
385
            # save network address into channels we might have with peer
386
            for chan in peer.channels.values():
×
387
                chan.add_or_update_peer_addr(peer_addr)
×
388

389
    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
4✔
390
        now = time.time()
×
391
        await self.channel_db.data_loaded.wait()
×
392
        # first try from recent peers
393
        recent_peers = self.channel_db.get_recent_peers()
×
394
        for peer in recent_peers:
×
395
            if not peer:
×
396
                continue
×
397
            if peer.pubkey in self._peers:
×
398
                continue
×
399
            if not self._can_retry_addr(peer, now=now):
×
400
                continue
×
401
            if not self.is_good_peer(peer):
×
402
                continue
×
403
            return [peer]
×
404
        # try random peer from graph
405
        unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
×
406
        if unconnected_nodes:
×
407
            for node_id in unconnected_nodes:
×
408
                addrs = self.channel_db.get_node_addresses(node_id)
×
409
                if not addrs:
×
410
                    continue
×
411
                host, port, timestamp = self.choose_preferred_address(list(addrs))
×
412
                try:
×
413
                    peer = LNPeerAddr(host, port, node_id)
×
414
                except ValueError:
×
415
                    continue
×
416
                if not self._can_retry_addr(peer, now=now):
×
417
                    continue
×
418
                if not self.is_good_peer(peer):
×
419
                    continue
×
420
                #self.logger.info('taking random ln peer from our channel db')
421
                return [peer]
×
422

423
        # getting desperate... let's try hardcoded fallback list of peers
424
        fallback_list = constants.net.FALLBACK_LN_NODES
×
425
        fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
×
426
        if fallback_list:
×
427
            return [random.choice(fallback_list)]
×
428

429
        # last resort: try dns seeds (BOLT-10)
NEW
430
        return await self._get_peers_from_dns_seeds()
×
431

432
    async def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
4✔
433
        # NOTE: potentially long blocking call, do not run directly on asyncio event loop.
434
        # Return several peers to reduce the number of dns queries.
435
        if not constants.net.LN_DNS_SEEDS:
×
436
            return []
×
437
        dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
×
438
        self.logger.info('asking dns seed "{}" for ln peers'.format(dns_seed))
×
439
        try:
×
440
            # note: this might block for several seconds
441
            # this will include bech32-encoded-pubkeys and ports
NEW
442
            srv_answers = await resolve_dns_srv('r{}.{}'.format(
×
443
                constants.net.LN_REALM_BYTE, dns_seed))
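            # i.e. we ask the seed for SRV records under "r<LN_REALM_BYTE>.<dns_seed>"
            # (BOLT-10 style query; the answers are parsed below)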
444
        except dns.exception.DNSException as e:
×
445
            self.logger.info(f'failed querying (1) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
446
            return []
×
447
        random.shuffle(srv_answers)
×
448
        num_peers = 2 * NUM_PEERS_TARGET
×
449
        srv_answers = srv_answers[:num_peers]
×
450
        # we now have pubkeys and ports but host is still needed
451
        peers = []
×
452
        for srv_ans in srv_answers:
×
453
            try:
×
454
                # note: this might take several seconds
NEW
455
                answers = await dns.asyncresolver.resolve(srv_ans['host'])
×
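                # (no rdtype given, so dnspython defaults to an A-record lookup,
                #  i.e. we get an IP address for the SRV target host)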
456
            except dns.exception.DNSException as e:
×
457
                self.logger.info(f'failed querying (2) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
458
                continue
×
459
            try:
×
460
                ln_host = str(answers[0])
×
461
                port = int(srv_ans['port'])
×
462
                bech32_pubkey = srv_ans['host'].split('.')[0]
×
463
                pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
×
464
                peers.append(LNPeerAddr(ln_host, port, pubkey))
×
465
            except Exception as e:
×
466
                self.logger.info(f'error with parsing peer from dns seed: {repr(e)}')
×
467
                continue
×
468
        self.logger.info(f'got {len(peers)} ln peers from dns seed')
×
469
        return peers
×
470

471
    @staticmethod
4✔
472
    def choose_preferred_address(addr_list: Sequence[Tuple[str, int, int]]) -> Tuple[str, int, int]:
4✔
473
        assert len(addr_list) >= 1
×
474
        # choose the most recent one that is an IP
475
        for host, port, timestamp in sorted(addr_list, key=lambda a: -a[2]):
×
476
            if is_ip_address(host):
×
477
                return host, port, timestamp
×
478
        # otherwise choose one at random
479
        # TODO maybe filter out onion if not on tor?
480
        choice = random.choice(addr_list)
×
481
        return choice
×
482

483
    @event_listener
4✔
484
    def on_event_proxy_set(self, *args):
4✔
485
        for peer in self.peers.values():
×
486
            peer.close_and_cleanup()
×
487
        self._clear_addr_retry_times()
×
488

489
    @log_exceptions
4✔
490
    async def add_peer(self, connect_str: str) -> Peer:
4✔
491
        node_id, rest = extract_nodeid(connect_str)
×
492
        peer = self._peers.get(node_id)
×
493
        if not peer:
×
494
            if rest is not None:
×
495
                host, port = split_host_port(rest)
×
496
            else:
497
                if self.uses_trampoline():
×
498
                    addr = trampolines_by_id().get(node_id)
×
499
                    if not addr:
×
500
                        raise ConnStringFormatError(_('Address unknown for node:') + ' ' + node_id.hex())
×
501
                    host, port = addr.host, addr.port
×
502
                else:
503
                    addrs = self.channel_db.get_node_addresses(node_id)
×
504
                    if not addrs:
×
505
                        raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + node_id.hex())
×
506
                    host, port, timestamp = self.choose_preferred_address(list(addrs))
×
507
            port = int(port)
×
508

509
            if not self.network.proxy:
×
510
                # Try DNS-resolving the host (if needed). This is simply so that
511
                # the caller gets a nice exception if it cannot be resolved.
512
                # (we don't do the DNS lookup if a proxy is set, to avoid a DNS-leak)
513
                if host.endswith('.onion'):
×
514
                    raise ConnStringFormatError(_('.onion address, but no proxy configured'))
×
515
                try:
×
516
                    await asyncio.get_running_loop().getaddrinfo(host, port)
×
517
                except socket.gaierror:
×
518
                    raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
×
519

520
            # add peer
521
            peer = await self._add_peer(host, port, node_id)
×
522
        return peer
×
523

524

525
class LNGossip(LNWorker):
4✔
526
    """The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
527
    gossip from other nodes. The LNGossip node does not satisfy gossip queries; this is done by the
528
    LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
529
    peer queries without fetching gossip themselves. This separation is done so that gossip can be queried
530
    independently of the active LNWallets. LNGossip keeps a curated batch of gossip in _forwarding_gossip
531
    that is fetched by the LNWallets for regular forwarding."""
532
    max_age = 14*24*3600
4✔
533
    LOGGING_SHORTCUT = 'g'
4✔
534

535
    def __init__(self, config: 'SimpleConfig'):
4✔
536
        seed = os.urandom(32)
×
537
        node = BIP32Node.from_rootseed(seed, xtype='standard')
×
538
        xprv = node.to_xprv()
×
539
        node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
×
540
        LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
×
541
        self.unknown_ids = set()
×
542
        self._forwarding_gossip = []  # type: List[GossipForwardingMessage]
×
543
        self._last_gossip_batch_ts = 0  # type: int
×
544
        self._forwarding_gossip_lock = asyncio.Lock()
×
545
        self.gossip_request_semaphore = asyncio.Semaphore(5)
×
546
        # statistics
547
        self._num_chan_ann = 0
×
548
        self._num_node_ann = 0
×
549
        self._num_chan_upd = 0
×
550
        self._num_chan_upd_good = 0
×
551

552
    def start_network(self, network: 'Network'):
4✔
553
        super().start_network(network)
×
554
        for coro in [
×
555
                self._maintain_connectivity(),
556
                self.maintain_db(),
557
                self._maintain_forwarding_gossip()
558
        ]:
559
            tg_coro = self.taskgroup.spawn(coro)
×
560
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
561

562
    async def maintain_db(self):
4✔
563
        await self.channel_db.data_loaded.wait()
×
564
        while True:
×
565
            if len(self.unknown_ids) == 0:
×
566
                self.channel_db.prune_old_policies(self.max_age)
×
567
                self.channel_db.prune_orphaned_channels()
×
568
            await asyncio.sleep(120)
×
569

570
    async def _maintain_forwarding_gossip(self):
4✔
571
        await self.channel_db.data_loaded.wait()
×
572
        await self.wait_for_sync()
×
573
        while True:
×
574
            async with self._forwarding_gossip_lock:
×
575
                self._forwarding_gossip = self.channel_db.get_forwarding_gossip_batch()
×
576
                self._last_gossip_batch_ts = int(time.time())
×
577
            self.logger.debug(f"{len(self._forwarding_gossip)} gossip messages available to forward")
×
578
            await asyncio.sleep(60)
×
579

580
    async def get_forwarding_gossip(self) -> tuple[List[GossipForwardingMessage], int]:
4✔
581
        async with self._forwarding_gossip_lock:
×
582
            return self._forwarding_gossip, self._last_gossip_batch_ts
×
583

584
    async def add_new_ids(self, ids: Iterable[bytes]):
4✔
585
        known = self.channel_db.get_channel_ids()
×
586
        new = set(ids) - set(known)
×
587
        self.unknown_ids.update(new)
×
588
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
589
        util.trigger_callback('gossip_peers', self.num_peers())
×
590
        util.trigger_callback('ln_gossip_sync_progress')
×
591

592
    def get_ids_to_query(self) -> Sequence[bytes]:
4✔
593
        N = 500
×
594
        l = list(self.unknown_ids)
×
595
        self.unknown_ids = set(l[N:])
×
596
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
597
        util.trigger_callback('ln_gossip_sync_progress')
×
598
        return l[0:N]
×
599

600
    def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int], Optional[int]]:
4✔
601
        """Estimates the gossip synchronization process and returns the number
602
        of synchronized channels, the total channels in the network and a
603
        rescaled percentage of the synchronization process."""
604
        if self.num_peers() == 0:
×
605
            return None, None, None
×
606
        nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
×
607
        num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
×
608
        num_nodes = self.channel_db.num_nodes
×
609
        num_nodes_associated_to_chans = max(len(self.channel_db._channels_for_node.keys()), 1)
×
610
        # some channels will never have two policies (only one is in gossip?...)
611
        # so if we have at least 1 policy for a channel, we consider that channel "complete" here
612
        current_est = num_db_channels - nchans_with_0p
×
613
        total_est = len(self.unknown_ids) + num_db_channels
×
614

615
        progress_chans = current_est / total_est if total_est and current_est else 0
×
616
        # consider that we got at least 10% of the node anns of node ids we know about
617
        progress_nodes = min((num_nodes / num_nodes_associated_to_chans) * 10, 1)
×
618
        progress = (progress_chans * 3 + progress_nodes) / 4  # weigh the channel progress higher
×
619
        # self.logger.debug(f"Sync process chans: {progress_chans} | Progress nodes: {progress_nodes} | "
620
        #                   f"Total progress: {progress} | NUM_NODES: {num_nodes} / {num_nodes_associated_to_chans}")
621
        progress_percent = (1.0 / 0.95 * progress) * 100
×
622
        progress_percent = min(progress_percent, 100)
×
623
        progress_percent = round(progress_percent)
×
624
        # take a minimal number of synchronized channels to get a more accurate
625
        # percentage estimate
626
        if current_est < 200:
×
627
            progress_percent = 0
×
628
        return current_est, total_est, progress_percent
×
629

630
    async def process_gossip(self, chan_anns, node_anns, chan_upds):
4✔
631
        # note: we run in the originating peer's TaskGroup, so we can safely raise here
632
        #       and disconnect only from that peer
633
        await self.channel_db.data_loaded.wait()
×
634

635
        # channel announcements
636
        def process_chan_anns():
×
637
            for payload in chan_anns:
×
638
                self.channel_db.verify_channel_announcement(payload)
×
639
            self.channel_db.add_channel_announcements(chan_anns)
×
640
        await run_in_thread(process_chan_anns)
×
641

642
        # node announcements
643
        def process_node_anns():
×
644
            for payload in node_anns:
×
645
                self.channel_db.verify_node_announcement(payload)
×
646
            self.channel_db.add_node_announcements(node_anns)
×
647
        await run_in_thread(process_node_anns)
×
648
        # channel updates
649
        categorized_chan_upds = await run_in_thread(partial(
×
650
            self.channel_db.add_channel_updates,
651
            chan_upds,
652
            max_age=self.max_age))
653
        orphaned = categorized_chan_upds.orphaned
×
654
        if orphaned:
×
655
            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
×
656
            orphaned_ids = [c['short_channel_id'] for c in orphaned]
×
657
            await self.add_new_ids(orphaned_ids)
×
658

659
        self._num_chan_ann += len(chan_anns)
×
660
        self._num_node_ann += len(node_anns)
×
661
        self._num_chan_upd += len(chan_upds)
×
662
        self._num_chan_upd_good += len(categorized_chan_upds.good)
×
663

664
    def is_synced(self) -> bool:
4✔
665
        _, _, percentage_synced = self.get_sync_progress_estimate()
×
666
        if percentage_synced is not None and percentage_synced >= 100:
×
667
            return True
×
668
        return False
×
669

670
    async def wait_for_sync(self, times_to_check: int = 3):
4✔
671
        """Check if we have 100% sync progress `times_to_check` times in a row (because the
672
        estimate often jumps back after some seconds when doing initial sync)."""
673
        while True:
×
674
            if self.is_synced():
×
675
                times_to_check -= 1
×
676
                if times_to_check <= 0:
×
677
                    return
×
678
            await asyncio.sleep(10)
×
679
            # flush the gossip queue so we don't forward old gossip after sync is complete
680
            self.channel_db.get_forwarding_gossip_batch()
×
681

682

683
class PaySession(Logger):
4✔
684
    def __init__(
4✔
685
            self,
686
            *,
687
            payment_hash: bytes,
688
            payment_secret: bytes,
689
            initial_trampoline_fee_level: int,
690
            invoice_features: int,
691
            r_tags,
692
            min_final_cltv_delta: int,  # delta for last node (typically from invoice)
693
            amount_to_pay: int,  # total payment amount final receiver will get
694
            invoice_pubkey: bytes,
695
            uses_trampoline: bool,  # whether sender uses trampoline or gossip
696
            use_two_trampolines: bool,  # whether legacy payments will try to use two trampolines
697
    ):
698
        assert payment_hash
4✔
699
        assert payment_secret
4✔
700
        self.payment_hash = payment_hash
4✔
701
        self.payment_secret = payment_secret
4✔
702
        self.payment_key = payment_hash + payment_secret
4✔
703
        Logger.__init__(self)
4✔
704

705
        self.invoice_features = LnFeatures(invoice_features)
4✔
706
        self.r_tags = r_tags
4✔
707
        self.min_final_cltv_delta = min_final_cltv_delta
4✔
708
        self.amount_to_pay = amount_to_pay
4✔
709
        self.invoice_pubkey = invoice_pubkey
4✔
710

711
        self.sent_htlcs_q = asyncio.Queue()  # type: asyncio.Queue[HtlcLog]
4✔
712
        self.start_time = time.time()
4✔
713

714
        self.uses_trampoline = uses_trampoline
4✔
715
        self.trampoline_fee_level = initial_trampoline_fee_level
4✔
716
        self.failed_trampoline_routes = []
4✔
717
        self.use_two_trampolines = use_two_trampolines
4✔
718
        self._sent_buckets = dict()  # psecret_bucket -> (amount_sent, amount_failed)
4✔
719

720
        self._amount_inflight = 0  # what we sent in htlcs (that receiver gets, without fees)
4✔
721
        self._nhtlcs_inflight = 0
4✔
722
        self.is_active = True  # is still trying to send new htlcs?
4✔
723

724
    def diagnostic_name(self):
4✔
725
        pkey = sha256(self.payment_key)
4✔
726
        return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
4✔
727

728
    def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
4✔
729
        if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
4✔
730
            self.trampoline_fee_level += 1
4✔
731
            self.failed_trampoline_routes = []
4✔
732
            self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
4✔
733
        else:
734
            self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
4✔
735

736
    def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
4✔
737
        # FIXME The trampoline nodes in the path are chosen randomly.
738
        #       Some of the errors might depend on how we have chosen them.
739
        #       Having more attempts is currently useful in part because of the randomness,
740
        #       instead we should give feedback to create_routes_for_payment.
741
        # Sometimes the trampoline node fails to send a payment and returns
742
        # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
743
        if failure_msg.code in (
4✔
744
                OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
745
                OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
746
                OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
747
            # TODO: parse the node policy here (not returned by eclair yet)
748
            # TODO: erring node is always the first trampoline even if second
749
            #  trampoline demands more fees, we can't influence this
750
            self.maybe_raise_trampoline_fee(htlc_log)
4✔
751
        elif self.use_two_trampolines:
4✔
752
            self.use_two_trampolines = False
×
753
        elif failure_msg.code in (
4✔
754
                OnionFailureCode.UNKNOWN_NEXT_PEER,
755
                OnionFailureCode.TEMPORARY_NODE_FAILURE):
756
            trampoline_route = htlc_log.route
4✔
757
            r = [hop.end_node.hex() for hop in trampoline_route]
4✔
758
            self.logger.info(f'failed trampoline route: {r}')
4✔
759
            if r not in self.failed_trampoline_routes:
4✔
760
                self.failed_trampoline_routes.append(r)
4✔
761
            else:
762
                pass  # maybe the route was reused between different MPP parts
×
763
        else:
764
            raise PaymentFailure(failure_msg.code_name())
4✔
765

766
    async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
4✔
767
        self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
4✔
768
        htlc_log = await self.sent_htlcs_q.get()
4✔
769
        self._amount_inflight -= htlc_log.amount_msat
4✔
770
        self._nhtlcs_inflight -= 1
4✔
771
        if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
4✔
772
            raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
×
773
        return htlc_log
4✔
774

775
    def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
4✔
776
        self._nhtlcs_inflight += 1
4✔
777
        self._amount_inflight += sent_htlc_info.amount_receiver_msat
4✔
778
        if self._amount_inflight > self.amount_to_pay:  # safety belts
4✔
779
            raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
×
780
        shi = sent_htlc_info
4✔
781
        bkey = shi.payment_secret_bucket
4✔
782
        # if we sent MPP to a trampoline, add item to sent_buckets
783
        if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
4✔
784
            if bkey not in self._sent_buckets:
4✔
785
                self._sent_buckets[bkey] = (0, 0)
4✔
786
            amount_sent, amount_failed = self._sent_buckets[bkey]
4✔
787
            amount_sent += shi.amount_receiver_msat
4✔
788
            self._sent_buckets[bkey] = amount_sent, amount_failed
4✔
789

790
    def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
4✔
791
        shi = sent_htlc_info
4✔
792
        # check sent_buckets if we use trampoline
793
        bkey = shi.payment_secret_bucket
4✔
794
        if self.uses_trampoline and bkey in self._sent_buckets:
4✔
795
            amount_sent, amount_failed = self._sent_buckets[bkey]
4✔
796
            amount_failed += shi.amount_receiver_msat
4✔
797
            self._sent_buckets[bkey] = amount_sent, amount_failed
4✔
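            # the bucket only counts as failed once every htlc in it has failed;
            # until then no fail amount is propagated (None is returned below)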
798
            if amount_sent != amount_failed:
4✔
799
                self.logger.info('bucket still active...')
4✔
800
                return None
4✔
801
            self.logger.info('bucket failed')
4✔
802
            return amount_sent
4✔
803
        # not using trampoline buckets
804
        return shi.amount_receiver_msat
4✔
805

806
    def get_outstanding_amount_to_send(self) -> int:
4✔
807
        return self.amount_to_pay - self._amount_inflight
4✔
808

809
    def can_be_deleted(self) -> bool:
4✔
810
        """Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
811
        if self.is_active:
4✔
812
            return False
4✔
813
        # note: no one is consuming from sent_htlcs_q anymore
814
        nhtlcs_resolved = self.sent_htlcs_q.qsize()
4✔
815
        assert nhtlcs_resolved <= self._nhtlcs_inflight
4✔
816
        return nhtlcs_resolved == self._nhtlcs_inflight
4✔
817

818

819
class LNWallet(LNWorker):
4✔
820

821
    lnwatcher: Optional['LNWatcher']
4✔
822
    MPP_EXPIRY = 120
4✔
823
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds
4✔
824
    PAYMENT_TIMEOUT = 120
4✔
825
    MPP_SPLIT_PART_FRACTION = 0.2
4✔
826
    MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
4✔
827

828
    def __init__(self, wallet: 'Abstract_Wallet', xprv):
4✔
829
        self.wallet = wallet
4✔
830
        self.config = wallet.config
4✔
831
        self.db = wallet.db
4✔
832
        self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
4✔
833
        self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey
4✔
834
        self.static_payment_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_BASE)
4✔
835
        self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
4✔
836
        self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
4✔
837
        Logger.__init__(self)
4✔
838
        features = LNWALLET_FEATURES
4✔
839
        if self.config.ENABLE_ANCHOR_CHANNELS:
4✔
840
            features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
4✔
841
        if self.config.ACCEPT_ZEROCONF_CHANNELS:
4✔
842
            features |= LnFeatures.OPTION_ZEROCONF_OPT
×
843
        if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
4✔
844
            features |= LnFeatures.GOSSIP_QUERIES_OPT  # signal we have gossip to fetch
×
845
        LNWorker.__init__(self, self.node_keypair, features, config=self.config)
4✔
846
        self.lnwatcher = LNWatcher(self)
4✔
847
        self.lnrater: LNRater = None
4✔
848
        self.payment_info = self.db.get_dict('lightning_payments')     # RHASH -> amount, direction, is_paid
4✔
849
        self._preimages = self.db.get_dict('lightning_preimages')   # RHASH -> preimage
4✔
850
        self._bolt11_cache = {}
4✔
851
        # note: this sweep_address is only used as a fallback, as it might result in address reuse
852
        self.logs = defaultdict(list)  # type: Dict[str, List[HtlcLog]]  # key is RHASH  # (not persisted)
4✔
853
        # used in tests
854
        self.enable_htlc_settle = True
4✔
855
        self.enable_htlc_forwarding = True
4✔
856

857
        # note: accessing channels (besides simple lookup) needs self.lock!
858
        self._channels = {}  # type: Dict[bytes, Channel]
4✔
859
        channels = self.db.get_dict("channels")
4✔
860
        for channel_id, c in random_shuffled_copy(channels.items()):
4✔
861
            self._channels[bfh(channel_id)] = chan = Channel(c, lnworker=self)
4✔
862
            self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
4✔
863

864
        self._channel_backups = {}  # type: Dict[bytes, ChannelBackup]
4✔
865
        # order is important: imported should overwrite onchain
866
        for name in ["onchain_channel_backups", "imported_channel_backups"]:
4✔
867
            channel_backups = self.db.get_dict(name)
4✔
868
            for channel_id, storage in channel_backups.items():
4✔
869
                self._channel_backups[bfh(channel_id)] = cb = ChannelBackup(storage, lnworker=self)
×
870
                self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
871

872
        self._paysessions = dict()                      # type: Dict[bytes, PaySession]
4✔
873
        self.sent_htlcs_info = dict()                   # type: Dict[SentHtlcKey, SentHtlcInfo]
4✔
874
        self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs')   # type: Dict[str, ReceivedMPPStatus]  # payment_key -> ReceivedMPPStatus
4✔
875

876
        # detect inflight payments
877
        self.inflight_payments = set()        # (not persisted) keys of invoices that are in PR_INFLIGHT state
4✔
878
        for payment_hash in self.get_payments(status='inflight').keys():
4✔
879
            self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
×
880

881
        # payment forwarding
882
        self.active_forwardings = self.db.get_dict('active_forwardings')    # type: Dict[str, List[str]]        # Dict: payment_key -> list of htlc_keys
4✔
883
        self.forwarding_failures = self.db.get_dict('forwarding_failures')  # type: Dict[str, Tuple[str, str]]  # Dict: payment_key -> (error_bytes, error_message)
4✔
884
        self.downstream_to_upstream_htlc = {}                               # type: Dict[str, str]              # Dict: htlc_key -> htlc_key (not persisted)
4✔
885
        self.dont_settle_htlcs = self.db.get_dict('dont_settle_htlcs')      # type: Dict[str, None]             # payment_hashes of htlcs that we should not settle back yet even if we have the preimage
4✔
886

887
        # payment_hash -> callback:
888
        self.hold_invoice_callbacks = {}                # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
4✔
889
        self.payment_bundles = []                       # lists of hashes. todo:persist
4✔
890

891
        self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
4✔
892
        self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
4✔
893
        self.onion_message_manager = OnionMessageManager(self)
4✔
894
        self.subscribe_to_channels()
4✔
895

896
    def subscribe_to_channels(self):
4✔
897
        for chan in self.channels.values():
4✔
898
            self.lnwatcher.add_channel(chan)
4✔
899
        for cb in self.channel_backups.values():
4✔
900
            self.lnwatcher.add_channel(cb)
×
901

902
    def has_deterministic_node_id(self) -> bool:
4✔
903
        return bool(self.db.get('lightning_xprv'))
×
904

905
    def can_have_recoverable_channels(self) -> bool:
4✔
906
        return (self.has_deterministic_node_id()
×
907
                and not self.config.LIGHTNING_LISTEN)
908

909
    def has_recoverable_channels(self) -> bool:
4✔
910
        """Whether *future* channels opened by this wallet would be recoverable
911
        from seed (via putting OP_RETURN outputs into funding txs).
912
        """
913
        return (self.can_have_recoverable_channels()
×
914
                and self.config.LIGHTNING_USE_RECOVERABLE_CHANNELS)
915

916
    def has_anchor_channels(self) -> bool:
4✔
917
        """Returns True if any active channel is an anchor channel."""
918
        return any(chan.has_anchors() and not chan.is_redeemed()
4✔
919
                    for chan in self.channels.values())
920

921
    @property
4✔
922
    def channels(self) -> Mapping[bytes, Channel]:
4✔
923
        """Returns a read-only copy of channels."""
924
        with self.lock:
4✔
925
            return self._channels.copy()
4✔
926

927
    @property
4✔
928
    def channel_backups(self) -> Mapping[bytes, ChannelBackup]:
4✔
929
        """Returns a read-only copy of channels."""
930
        with self.lock:
4✔
931
            return self._channel_backups.copy()
4✔
932

933
    def get_channel_objects(self) -> Mapping[bytes, AbstractChannel]:
4✔
934
        r = self.channel_backups
×
935
        r.update(self.channels)
×
936
        return r
×
937

938
    def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
4✔
939
        return self._channels.get(channel_id, None)
4✔
940

941
    def diagnostic_name(self):
4✔
942
        return self.wallet.diagnostic_name()
4✔
943

944
    @ignore_exceptions
4✔
945
    @log_exceptions
4✔
946
    async def sync_with_remote_watchtower(self):
4✔
947
        self.watchtower_ctns = {}
×
948
        while True:
×
949
            # periodically poll if the user updated 'watchtower_url'
950
            await asyncio.sleep(5)
×
951
            watchtower_url = self.config.WATCHTOWER_CLIENT_URL
×
952
            if not watchtower_url:
×
953
                continue
×
954
            parsed_url = urllib.parse.urlparse(watchtower_url)
×
955
            if not (parsed_url.scheme == 'https' or is_private_netaddress(parsed_url.hostname)):
×
956
                self.logger.warning(f"got watchtower URL for remote tower but we won't use it! "
×
957
                                    f"can only use HTTPS (except if private IP): not using {watchtower_url!r}")
958
                continue
×
959
            # try to sync with the remote watchtower
960
            try:
×
961
                async with make_aiohttp_session(proxy=self.network.proxy) as session:
×
962
                    watchtower = JsonRPCClient(session, watchtower_url)
×
963
                    watchtower.add_method('get_ctn')
×
964
                    watchtower.add_method('add_sweep_tx')
×
965
                    for chan in self.channels.values():
×
966
                        await self.sync_channel_with_watchtower(chan, watchtower)
×
967
            except aiohttp.client_exceptions.ClientConnectorError:
×
968
                self.logger.info(f'could not contact remote watchtower {watchtower_url}')
×
969

970
    def get_watchtower_ctn(self, channel_point):
4✔
971
        return self.watchtower_ctns.get(channel_point)
×
972

973
    async def sync_channel_with_watchtower(self, chan: Channel, watchtower):
4✔
974
        outpoint = chan.funding_outpoint.to_str()
×
975
        addr = chan.get_funding_address()
×
976
        current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
×
977
        watchtower_ctn = await watchtower.get_ctn(outpoint, addr)
×
978
        for ctn in range(watchtower_ctn + 1, current_ctn):
×
979
            sweeptxs = chan.create_sweeptxs_for_watchtower(ctn)
×
980
            for tx in sweeptxs:
×
981
                await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
×
982
            self.watchtower_ctns[outpoint] = ctn
×
983

984
    def start_network(self, network: 'Network'):
4✔
985
        super().start_network(network)
×
986
        self.lnwatcher.start_network(network)
×
987
        self.swap_manager.start_network(network)
×
988
        self.lnrater = LNRater(self, network)
×
989
        self.onion_message_manager.start_network(network=network)
×
990

991
        for coro in [
×
992
                self.maybe_listen(),
993
                self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
994
                self.reestablish_peers_and_channels(),
995
                self.sync_with_remote_watchtower(),
996
        ]:
997
            tg_coro = self.taskgroup.spawn(coro)
×
998
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
999

1000
    async def stop(self):
4✔
1001
        self.stopping_soon = True
4✔
1002
        if self.listen_server:  # stop accepting new peers
4✔
1003
            self.listen_server.close()
×
1004
        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
4✔
1005
            await self.wait_for_received_pending_htlcs_to_get_removed()
4✔
1006
        await LNWorker.stop(self)
4✔
1007
        if self.lnwatcher:
4✔
1008
            self.lnwatcher.stop()
×
1009
            self.lnwatcher = None
×
1010
        if self.swap_manager and self.swap_manager.network:  # may not be present in tests
4✔
1011
            await self.swap_manager.stop()
×
1012
        if self.onion_message_manager:
4✔
1013
            await self.onion_message_manager.stop()
×
1014

1015
    async def wait_for_received_pending_htlcs_to_get_removed(self):
4✔
1016
        assert self.stopping_soon is True
4✔
1017
        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
1018
        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
1019
        #       to wait a bit for it to become irrevocably removed.
1020
        # Note: we don't wait for *all htlcs* to get removed, only for those
1021
        #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
1022
        async with OldTaskGroup() as group:
4✔
1023
            for peer in self.peers.values():
4✔
1024
                await group.spawn(peer.wait_one_htlc_switch_iteration())
4✔
1025
        while True:
4✔
1026
            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
4✔
1027
                break
4✔
1028
            async with OldTaskGroup(wait=any) as group:
4✔
1029
                for peer in self.peers.values():
4✔
1030
                    await group.spawn(peer.received_htlc_removed_event.wait())
4✔
1031

1032
    def peer_closed(self, peer):
4✔
1033
        for chan in self.channels_for_peer(peer.pubkey).values():
×
1034
            chan.peer_state = PeerState.DISCONNECTED
×
1035
            util.trigger_callback('channel', self.wallet, chan)
×
1036
        super().peer_closed(peer)
×
1037

1038
    def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
4✔
1039
        out = defaultdict(list)
4✔
1040
        for chan in self.channels.values():
4✔
1041
            d = chan.get_payments(status=status)
4✔
1042
            for payment_hash, plist in d.items():
4✔
1043
                out[payment_hash] += plist
4✔
1044
        return out
4✔
1045

1046
    def get_payment_value(
4✔
1047
            self, info: Optional['PaymentInfo'],
1048
            plist: List[HTLCWithStatus]) -> Tuple[PaymentDirection, int, Optional[int], int]:
1049
        """ fee_msat is included in amount_msat"""
1050
        assert plist
×
1051
        amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist)
×
1052
        if all(x.direction == SENT for x in plist):
×
1053
            direction = PaymentDirection.SENT
×
1054
            fee_msat = (- info.amount_msat - amount_msat) if info else None
×
1055
        elif all(x.direction == RECEIVED for x in plist):
×
1056
            direction = PaymentDirection.RECEIVED
×
1057
            fee_msat = None
×
1058
        elif amount_msat < 0:
×
1059
            direction = PaymentDirection.SELF_PAYMENT
×
1060
            fee_msat = - amount_msat
×
1061
        else:
1062
            direction = PaymentDirection.FORWARDING
×
1063
            fee_msat = - amount_msat
×
1064
        timestamp = min([htlc_with_status.htlc.timestamp for htlc_with_status in plist])
×
1065
        return direction, amount_msat, fee_msat, timestamp
×
1066

1067
    def get_lightning_history(self) -> Dict[str, LightningHistoryItem]:
4✔
1068
        """
1069
        side effect: sets default labels
1070
        note that the result is not ordered
1071
        """
1072
        out = {}
×
1073
        for payment_hash, plist in self.get_payments(status='settled').items():
×
1074
            if len(plist) == 0:
×
1075
                continue
×
1076
            key = payment_hash.hex()
×
1077
            info = self.get_payment_info(payment_hash)
×
1078
            # note: just after successfully paying an invoice using MPP, amount and fee values might be shifted
1079
            #       temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also
1080
            #       have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than
1081
            #       expected and the fee higher (the inflight htlcs will be effectively counted as fees).
1082
            direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
×
1083
            label = self.wallet.get_label_for_rhash(key)
×
1084
            if not label and direction == PaymentDirection.FORWARDING:
×
1085
                label = _('Forwarding')
×
1086
            preimage = self.get_preimage(payment_hash).hex()
×
1087
            group_id = self.swap_manager.get_group_id_for_payment_hash(payment_hash)
×
1088
            item = LightningHistoryItem(
×
1089
                type = 'payment',
1090
                payment_hash = payment_hash.hex(),
1091
                preimage = preimage,
1092
                amount_msat = amount_msat,
1093
                fee_msat = fee_msat,
1094
                group_id = group_id,
1095
                timestamp = timestamp or 0,
1096
                label=label,
1097
                direction=direction,
1098
            )
1099
            out[payment_hash.hex()] = item
×
1100
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
×
1101
            item = chan.get_funding_height()
×
1102
            if item is None:
×
1103
                continue
×
1104
            funding_txid, funding_height, funding_timestamp = item
×
1105
            label = _('Open channel') + ' ' + chan.get_id_for_log()
×
1106
            self.wallet.set_default_label(funding_txid, label)
×
1107
            self.wallet.set_group_label(funding_txid, label)
×
1108
            item = LightningHistoryItem(
×
1109
                type = 'channel_opening',
1110
                label = label,
1111
                group_id = funding_txid,
1112
                timestamp = funding_timestamp,
1113
                amount_msat = chan.balance(LOCAL, ctn=0),
1114
                fee_msat = None,
1115
                payment_hash = None,
1116
                preimage = None,
1117
                direction=None,
1118
            )
1119
            out[funding_txid] = item
×
1120
            item = chan.get_closing_height()
×
1121
            if item is None:
×
1122
                continue
×
1123
            closing_txid, closing_height, closing_timestamp = item
×
1124
            label = _('Close channel') + ' ' + chan.get_id_for_log()
×
1125
            self.wallet.set_default_label(closing_txid, label)
×
1126
            self.wallet.set_group_label(closing_txid, label)
×
1127
            item = LightningHistoryItem(
×
1128
                type = 'channel_closing',
1129
                label = label,
1130
                group_id = closing_txid,
1131
                timestamp = closing_timestamp,
1132
                amount_msat = -chan.balance(LOCAL),
1133
                fee_msat = None,
1134
                payment_hash = None,
1135
                preimage = None,
1136
                direction=None,
1137
            )
1138
            out[closing_txid] = item
×
1139

1140
        # sanity check
1141
        balance_msat = sum([x.amount_msat for x in out.values()])
×
1142
        lb = sum(chan.balance(LOCAL) if not chan.is_closed_or_closing() else 0
×
1143
                for chan in self.channels.values())
1144
        if balance_msat != lb:
×
1145
            # this typically happens when a channel is recently force closed
1146
            self.logger.info(f'get_lightning_history: balance mismatch {balance_msat - lb}')
×
1147
        return out
×
1148

1149
    def get_groups_for_onchain_history(self) -> Dict[str, str]:
        """
        returns dict: txid -> group_id
        side effect: sets default labels
        """
        groups = {}
        # add funding events
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
            item = chan.get_funding_height()
            if item is None:
                continue
            funding_txid, funding_height, funding_timestamp = item
            groups[funding_txid] = funding_txid
            item = chan.get_closing_height()
            if item is None:
                continue
            closing_txid, closing_height, closing_timestamp = item
            groups[closing_txid] = closing_txid

        d = self.swap_manager.get_groups_for_onchain_history()
        for txid, v in d.items():
            group_id = v['group_id']
            label = v.get('label')
            group_label = v.get('group_label') or label
            groups[txid] = group_id
            if label:
                self.wallet.set_default_label(txid, label)
            if group_label:
                self.wallet.set_group_label(group_id, group_label)

        return groups

    def channel_peers(self) -> List[bytes]:
        node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
        return node_ids

    def channels_for_peer(self, node_id):
        assert type(node_id) is bytes
        return {chan_id: chan for (chan_id, chan) in self.channels.items()
                if chan.node_id == node_id}

    def channel_state_changed(self, chan: Channel):
        if type(chan) is Channel:
            self.save_channel(chan)
        self.clear_invoices_cache()
        util.trigger_callback('channel', self.wallet, chan)

    def save_channel(self, chan: Channel):
        assert type(chan) is Channel
        if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
            raise Exception("Tried to save channel with next_point == current_point, this should not happen")
        self.wallet.save_db()
        util.trigger_callback('channel', self.wallet, chan)

    def channel_by_txo(self, txo: str) -> Optional[AbstractChannel]:
        for chan in self.channels.values():
            if chan.funding_outpoint.to_str() == txo:
                return chan
        for chan in self.channel_backups.values():
            if chan.funding_outpoint.to_str() == txo:
                return chan

    async def handle_onchain_state(self, chan: Channel):
        if self.network is None:
            # network not started yet
            return

        if type(chan) is ChannelBackup:
            util.trigger_callback('channel', self.wallet, chan)
            return

        if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
                and chan.should_be_closed_due_to_expiring_htlcs(self.wallet.adb.get_local_height())):
            self.logger.info(f"force-closing due to expiring htlcs")
            await self.schedule_force_closing(chan.channel_id)

        elif chan.get_state() == ChannelState.FUNDED:
            peer = self._peers.get(chan.node_id)
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
                peer.send_channel_ready(chan)

        elif chan.get_state() == ChannelState.OPEN:
            peer = self._peers.get(chan.node_id)
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
                peer.maybe_update_fee(chan)
                peer.maybe_send_announcement_signatures(chan)

        elif chan.get_state() == ChannelState.FORCE_CLOSING:
            force_close_tx = chan.force_close_tx()
            txid = force_close_tx.txid()
            height = self.lnwatcher.adb.get_tx_height(txid).height
            if height == TX_HEIGHT_LOCAL:
                self.logger.info('REBROADCASTING CLOSING TX')
                await self.network.try_broadcasting(force_close_tx, 'force-close')

    def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
        for nodeid, peer in self.peers.items():
            if scid_alias == self._scid_alias_of_node(nodeid):
                return peer

    def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
        # scid alias for just-in-time channels
        return sha256(b'Electrum' + nodeid)[0:8]

    def get_static_jit_scid_alias(self) -> bytes:
        return self._scid_alias_of_node(self.node_keypair.pubkey)

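    # Illustrative sketch of the derivation above (not part of lnworker.py): the static
    # just-in-time scid alias is simply the first 8 bytes of sha256(b'Electrum' || node_id).
    # The node id below is a hypothetical placeholder, not a real node:
    #
    #   >>> from electrum.crypto import sha256
    #   >>> node_id = bytes.fromhex('02' + '11' * 32)      # fake 33-byte compressed pubkey
    #   >>> alias = sha256(b'Electrum' + node_id)[0:8]
    #   >>> len(alias)
    #   8
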
    @log_exceptions
    async def open_channel_just_in_time(
        self,
        *,
        next_peer: Peer,
        next_amount_msat_htlc: int,
        next_cltv_abs: int,
        payment_hash: bytes,
        next_onion: OnionPacket,
    ) -> str:
        # if an exception is raised during negotiation, we raise an OnionRoutingFailure.
        # this will cancel the incoming HTLC

        # prevent settling the htlc until the channel opening was successful, so we can fail it if needed
        self.dont_settle_htlcs[payment_hash.hex()] = None
        try:
            funding_sat = 2 * (next_amount_msat_htlc // 1000)  # try to fully spend htlcs
            password = self.wallet.get_unlocked_password() if self.wallet.has_password() else None
            channel_opening_fee = next_amount_msat_htlc // 100
            if channel_opening_fee // 1000 < self.config.ZEROCONF_MIN_OPENING_FEE:
                self.logger.info(f'rejecting JIT channel: payment too low')
                raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'payment too low')
            self.logger.info(f'channel opening fee (sats): {channel_opening_fee//1000}')
            next_chan, funding_tx = await self.open_channel_with_peer(
                next_peer, funding_sat,
                push_sat=0,
                zeroconf=True,
                public=False,
                opening_fee=channel_opening_fee,
                password=password,
            )
            async def wait_for_channel():
                while not next_chan.is_open():
                    await asyncio.sleep(1)
            await util.wait_for2(wait_for_channel(), LN_P2P_NETWORK_TIMEOUT)
            next_chan.save_remote_scid_alias(self._scid_alias_of_node(next_peer.pubkey))
            self.logger.info(f'JIT channel is open')
            next_amount_msat_htlc -= channel_opening_fee
            # fixme: some checks are missing
            htlc = next_peer.send_htlc(
                chan=next_chan,
                payment_hash=payment_hash,
                amount_msat=next_amount_msat_htlc,
                cltv_abs=next_cltv_abs,
                onion=next_onion)
            async def wait_for_preimage():
                while self.get_preimage(payment_hash) is None:
                    await asyncio.sleep(1)
            await util.wait_for2(wait_for_preimage(), LN_P2P_NETWORK_TIMEOUT)

            # We have been paid and can broadcast
            # todo: if broadcasting raises an exception, we should try to rebroadcast
            await self.network.broadcast_transaction(funding_tx)
        except OnionRoutingFailure:
            raise
        except Exception:
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
        finally:
            del self.dont_settle_htlcs[payment_hash.hex()]

        htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id)
        return htlc_key

    @log_exceptions
    async def open_channel_with_peer(
            self, peer, funding_sat, *,
            push_sat: int = 0,
            public: bool = False,
            zeroconf: bool = False,
            opening_fee: int = None,
            password=None):
        if self.config.ENABLE_ANCHOR_CHANNELS:
            self.wallet.unlock(password)
        coins = self.wallet.get_spendable_coins(None)
        node_id = peer.pubkey
        fee_policy = FeePolicy(self.config.FEE_POLICY)
        funding_tx = self.mktx_for_open_channel(
            coins=coins,
            funding_sat=funding_sat,
            node_id=node_id,
            fee_policy=fee_policy)
        chan, funding_tx = await self._open_channel_coroutine(
            peer=peer,
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_sat=push_sat,
            public=public,
            zeroconf=zeroconf,
            opening_fee=opening_fee,
            password=password)
        return chan, funding_tx

    @log_exceptions
    async def _open_channel_coroutine(
            self, *,
            peer: Peer,
            funding_tx: PartialTransaction,
            funding_sat: int,
            push_sat: int,
            public: bool,
            zeroconf=False,
            opening_fee=None,
            password: Optional[str],
    ) -> Tuple[Channel, PartialTransaction]:

        if funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
            raise Exception(
                _("Requested channel capacity is over maximum.")
                + f"\n{funding_sat} sat > {self.config.LIGHTNING_MAX_FUNDING_SAT} sat"
            )
        coro = peer.channel_establishment_flow(
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_msat=push_sat * 1000,
            public=public,
            zeroconf=zeroconf,
            opening_fee=opening_fee,
            temp_channel_id=os.urandom(32))
        chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
        util.trigger_callback('channels_updated', self.wallet)
        self.wallet.adb.add_transaction(funding_tx)  # save tx as local into the wallet
        self.wallet.sign_transaction(funding_tx, password)
        if funding_tx.is_complete() and not zeroconf:
            await self.network.try_broadcasting(funding_tx, 'open_channel')
        return chan, funding_tx

    def add_channel(self, chan: Channel):
        with self.lock:
            self._channels[chan.channel_id] = chan
        self.lnwatcher.add_channel(chan)

    def add_new_channel(self, chan: Channel):
        self.add_channel(chan)
        channels_db = self.db.get_dict('channels')
        channels_db[chan.channel_id.hex()] = chan.storage
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
        try:
            self.save_channel(chan)
        except Exception:
            chan.set_state(ChannelState.REDEEMED)
            self.remove_channel(chan.channel_id)
            raise

    def cb_data(self, node_id: bytes) -> bytes:
        return CB_MAGIC_BYTES + node_id[0:NODE_ID_PREFIX_LEN]

    def decrypt_cb_data(self, encrypted_data, funding_address):
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
        nonce = funding_scripthash[0:12]
        return chacha20_decrypt(key=self.backup_key, data=encrypted_data, nonce=nonce)

    def encrypt_cb_data(self, data, funding_address):
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
        nonce = funding_scripthash[0:12]
        # note: we are only using chacha20 instead of chacha20+poly1305 to save onchain space
        #       (not have the 16 byte MAC). Otherwise, the latter would be preferable.
        return chacha20_encrypt(key=self.backup_key, data=data, nonce=nonce)

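    # Illustrative round-trip sketch for the two helpers above (not part of lnworker.py).
    # The 32-byte key and the 12-byte nonce below are placeholders; in the real methods the
    # key is self.backup_key and the nonce is the first 12 bytes of the funding scripthash:
    #
    #   >>> from electrum.crypto import chacha20_encrypt, chacha20_decrypt
    #   >>> key, nonce = bytes(32), bytes(12)
    #   >>> blob = chacha20_encrypt(key=key, data=b'channel backup data', nonce=nonce)
    #   >>> chacha20_decrypt(key=key, data=blob, nonce=nonce)
    #   b'channel backup data'
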
    def mktx_for_open_channel(
            self, *,
            coins: Sequence[PartialTxInput],
            funding_sat: int,
            node_id: bytes,
            fee_policy: FeePolicy,
    ) -> PartialTransaction:
        from .wallet import get_locktime_for_new_transaction

        outputs = [PartialTxOutput.from_address_and_value(DummyAddress.CHANNEL, funding_sat)]
        if self.has_recoverable_channels():
            dummy_scriptpubkey = make_op_return(self.cb_data(node_id))
            outputs.append(PartialTxOutput(scriptpubkey=dummy_scriptpubkey, value=0))
        tx = self.wallet.make_unsigned_transaction(
            coins=coins,
            outputs=outputs,
            fee_policy=fee_policy,
            # we do not know yet if peer accepts anchors, just assume they do
            is_anchor_channel_opening=self.config.ENABLE_ANCHOR_CHANNELS,
        )
        tx.set_rbf(False)
        # rm randomness from locktime, as we use the locktime as entropy for deriving the funding_privkey
        # (and it would be confusing to get a collision as a consequence of the randomness)
        tx.locktime = get_locktime_for_new_transaction(self.network, include_random_component=False)
        return tx

    def suggest_funding_amount(self, amount_to_pay, coins) -> Tuple[int, int] | None:
        """whether we can pay amount_to_pay (in sat) after opening a new channel"""
        num_sats_can_send = int(self.num_sats_can_send())
        lightning_needed = amount_to_pay - num_sats_can_send
        assert lightning_needed > 0
        min_funding_sat = lightning_needed + (lightning_needed // 20) + 1000  # safety margin
        min_funding_sat = max(min_funding_sat, MIN_FUNDING_SAT)  # at least MIN_FUNDING_SAT
        if min_funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
            return
        fee_policy = FeePolicy(f'feerate:{FEERATE_FALLBACK_STATIC_FEE}')
        try:
            self.mktx_for_open_channel(coins=coins, funding_sat=min_funding_sat, node_id=bytes(32), fee_policy=fee_policy)
            funding_sat = min_funding_sat
        except NotEnoughFunds:
            return
        # if available, suggest twice that amount:
        if 2 * min_funding_sat <= self.config.LIGHTNING_MAX_FUNDING_SAT:
            try:
                self.mktx_for_open_channel(coins=coins, funding_sat=2*min_funding_sat, node_id=bytes(32), fee_policy=fee_policy)
                funding_sat = 2 * min_funding_sat
            except NotEnoughFunds:
                pass
        return funding_sat, min_funding_sat

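    # Worked example for the sizing above (illustrative numbers only): if the invoice needs
    # 100_000 sat and num_sats_can_send() currently returns 20_000 sat, then
    #
    #   >>> lightning_needed = 100_000 - 20_000
    #   >>> lightning_needed + lightning_needed // 20 + 1000
    #   85000
    #
    # so 85_000 sat (or MIN_FUNDING_SAT, if that is larger) is the minimum suggested funding,
    # and twice that amount is suggested when the wallet can afford it.
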
    def open_channel(
            self, *,
            connect_str: str,
            funding_tx: PartialTransaction,
            funding_sat: int,
            push_amt_sat: int,
            public: bool = False,
            password: str = None,
    ) -> Tuple[Channel, PartialTransaction]:

        fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
        try:
            peer = fut.result()
        except concurrent.futures.TimeoutError:
            raise Exception(_("add peer timed out"))
        coro = self._open_channel_coroutine(
            peer=peer,
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_sat=push_amt_sat,
            public=public,
            password=password)
        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
        try:
            chan, funding_tx = fut.result()
        except concurrent.futures.TimeoutError:
            raise Exception(_("open_channel timed out"))
        return chan, funding_tx

    def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
        # First check against *real* SCIDs.
        # This e.g. protects against maliciously chosen SCID aliases, and accidental collisions.
        for chan in self.channels.values():
            if chan.short_channel_id == short_channel_id:
                return chan
        # Now we also consider aliases.
        # TODO we should split this as this search currently ignores the "direction"
        #      of the aliases. We should only look at either the remote OR the local alias,
        #      depending on context.
        for chan in self.channels.values():
            if chan.get_remote_scid_alias() == short_channel_id:
                return chan
            if chan.get_local_scid_alias() == short_channel_id:
                return chan

    def can_pay_invoice(self, invoice: Invoice) -> bool:
        assert invoice.is_lightning()
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_send()

    @log_exceptions
    async def pay_invoice(
            self, invoice: Invoice, *,
            amount_msat: int = None,
            attempts: int = None,  # used only in unit tests
            full_path: LNPaymentPath = None,
            channels: Optional[Sequence[Channel]] = None,
    ) -> Tuple[bool, List[HtlcLog]]:
        bolt11 = invoice.lightning_invoice
        lnaddr = self._check_bolt11_invoice(bolt11, amount_msat=amount_msat)
        min_final_cltv_delta = lnaddr.get_min_final_cltv_delta()
        payment_hash = lnaddr.paymenthash
        key = payment_hash.hex()
        payment_secret = lnaddr.payment_secret
        invoice_pubkey = lnaddr.pubkey.serialize()
        invoice_features = lnaddr.get_features()
        r_tags = lnaddr.get_routing_info('r')
        amount_to_pay = lnaddr.get_amount_msat()
        status = self.get_payment_status(payment_hash)
        if status == PR_PAID:
            raise PaymentFailure(_("This invoice has been paid already"))
        if status == PR_INFLIGHT:
            raise PaymentFailure(_("A payment was already initiated for this invoice"))
        if payment_hash in self.get_payments(status='inflight'):
            raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
        info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
        self.save_payment_info(info)
        self.wallet.set_label(key, lnaddr.get_description())
        self.set_invoice_status(key, PR_INFLIGHT)
        budget = PaymentFeeBudget.default(invoice_amount_msat=amount_to_pay, config=self.config)
        if attempts is None and self.uses_trampoline():
            # we don't expect lots of failed htlcs with trampoline, so we can fail sooner
            attempts = 30
        success = False
        try:
            await self.pay_to_node(
                node_pubkey=invoice_pubkey,
                payment_hash=payment_hash,
                payment_secret=payment_secret,
                amount_to_pay=amount_to_pay,
                min_final_cltv_delta=min_final_cltv_delta,
                r_tags=r_tags,
                invoice_features=invoice_features,
                attempts=attempts,
                full_path=full_path,
                channels=channels,
                budget=budget,
            )
            success = True
        except PaymentFailure as e:
            self.logger.info(f'payment failure: {e!r}')
            reason = str(e)
        except ChannelDBNotLoaded as e:
            self.logger.info(f'payment failure: {e!r}')
            reason = str(e)
        finally:
            self.logger.info(f"pay_invoice ending session for RHASH={payment_hash.hex()}. {success=}")
        if success:
            self.set_invoice_status(key, PR_PAID)
            util.trigger_callback('payment_succeeded', self.wallet, key)
        else:
            self.set_invoice_status(key, PR_UNPAID)
            util.trigger_callback('payment_failed', self.wallet, key, reason)
        log = self.logs[key]
        return success, log

    async def pay_to_node(
            self, *,
            node_pubkey: bytes,
            payment_hash: bytes,
            payment_secret: bytes,
            amount_to_pay: int,  # in msat
            min_final_cltv_delta: int,
            r_tags,
            invoice_features: int,
            attempts: int = None,
            full_path: LNPaymentPath = None,
            fwd_trampoline_onion: OnionPacket = None,
            budget: PaymentFeeBudget,
            channels: Optional[Sequence[Channel]] = None,
            fw_payment_key: str = None,  # for forwarding
    ) -> None:

        assert budget
        assert budget.fee_msat >= 0, budget
        assert budget.cltv >= 0, budget

        payment_key = payment_hash + payment_secret
        assert payment_key not in self._paysessions
        self._paysessions[payment_key] = paysession = PaySession(
            payment_hash=payment_hash,
            payment_secret=payment_secret,
            initial_trampoline_fee_level=self.config.INITIAL_TRAMPOLINE_FEE_LEVEL,
            invoice_features=invoice_features,
            r_tags=r_tags,
            min_final_cltv_delta=min_final_cltv_delta,
            amount_to_pay=amount_to_pay,
            invoice_pubkey=node_pubkey,
            uses_trampoline=self.uses_trampoline(),
            # the config option to use two trampoline hops for legacy payments has been removed as
            # the trampoline onion is too small (400 bytes) to accommodate two trampoline hops and
            # routing hints, making the functionality unusable for payments that require routing hints.
            # TODO: if you read this, the year is 2027 and there is no use for the second trampoline
            # hop code anymore, remove the code completely.
            use_two_trampolines=False,
        )
        self.logs[payment_hash.hex()] = log = []  # TODO incl payment_secret in key (re trampoline forwarding)

        paysession.logger.info(
            f"pay_to_node starting session for RHASH={payment_hash.hex()}. "
            f"using_trampoline={self.uses_trampoline()}. "
            f"invoice_features={paysession.invoice_features.get_names()}. "
            f"{amount_to_pay=} msat. {budget=}")
        if not self.uses_trampoline():
            self.logger.info(
                f"gossip_db status. sync progress: {self.network.lngossip.get_sync_progress_estimate()}. "
                f"num_nodes={self.channel_db.num_nodes}, "
                f"num_channels={self.channel_db.num_channels}, "
                f"num_policies={self.channel_db.num_policies}.")

        # when encountering trampoline forwarding difficulties in the legacy case, we
        # sometimes need to fall back to a single trampoline forwarder, at the expense
        # of privacy
        try:
            while True:
                if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
                    # 1. create a set of routes for the remaining amount.
                    # note: path-finding runs in a separate thread so that we don't block the asyncio loop;
                    #       graph updates might occur during the computation
                    remaining_fee_budget_msat = (budget.fee_msat * amount_to_send) // amount_to_pay
                    routes = self.create_routes_for_payment(
                        paysession=paysession,
                        amount_msat=amount_to_send,
                        full_path=full_path,
                        fwd_trampoline_onion=fwd_trampoline_onion,
                        channels=channels,
                        budget=budget._replace(fee_msat=remaining_fee_budget_msat),
                    )
                    # 2. send htlcs
                    async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
                        await self.pay_to_route(
                            paysession=paysession,
                            sent_htlc_info=sent_htlc_info,
                            min_final_cltv_delta=cltv_delta,
                            trampoline_onion=trampoline_onion,
                            fw_payment_key=fw_payment_key,
                        )
                    # invoice_status is triggered in self.set_invoice_status when it actually changes.
                    # It is also triggered here to update progress for a lightning payment in the GUI
                    # (e.g. attempt counter)
                    util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
                # 3. await the htlc resolution queue
                htlc_log = await paysession.wait_for_one_htlc_to_resolve()  # TODO maybe wait a bit, more failures might come
                log.append(htlc_log)
                if htlc_log.success:
                    if self.network.path_finder:
                        # TODO: report every route to liquidity hints for mpp
                        # in the case of success, we report channels of the
                        # route as being able to send the same amount in the future,
                        # as we assume to not know the capacity
                        self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat)
                        # remove inflight htlcs from liquidity hints
                        self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
                    return
                # htlc failed
                # if we get a tmp channel failure, it might work to split the amount and try more routes
                # if we get a channel update, we might retry the same route and amount
                route = htlc_log.route
                sender_idx = htlc_log.sender_idx
                failure_msg = htlc_log.failure_msg
                if sender_idx is None:
                    raise PaymentFailure(failure_msg.code_name())
                erring_node_id = route[sender_idx].node_id
                code, data = failure_msg.code, failure_msg.data
                self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. "
                                 f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}")
                self.logger.info(f"error reported by {erring_node_id.hex()}")
                if code == OnionFailureCode.MPP_TIMEOUT:
                    raise PaymentFailure(failure_msg.code_name())
                # errors returned by the next trampoline.
                if fwd_trampoline_onion and code in [
                        OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
                        OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON]:
                    raise failure_msg
                # trampoline
                if self.uses_trampoline():
                    paysession.handle_failed_trampoline_htlc(
                        htlc_log=htlc_log, failure_msg=failure_msg)
                else:
                    self.handle_error_code_from_failed_htlc(
                        route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
                # max attempts or timeout
                if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
                    raise PaymentFailure('Giving up after %d attempts' % len(log))
        finally:
            paysession.is_active = False
            if paysession.can_be_deleted():
                self._paysessions.pop(payment_key)
            paysession.logger.info(f"pay_to_node ending session for RHASH={payment_hash.hex()}")

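    # The fee budget in the retry loop above is split in proportion to what is still
    # outstanding. Illustrative numbers (not from the source): with budget.fee_msat == 5_000
    # for a 1_000_000 msat invoice, and 400_000 msat still to send,
    #
    #   >>> (5_000 * 400_000) // 1_000_000
    #   2000
    #
    # i.e. 2_000 msat of fee budget is handed to the routes created for this round.
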
    async def pay_to_route(
            self, *,
            paysession: PaySession,
            sent_htlc_info: SentHtlcInfo,
            min_final_cltv_delta: int,
            trampoline_onion: Optional[OnionPacket] = None,
            fw_payment_key: str = None,
    ) -> None:
        """Sends a single HTLC."""
        shi = sent_htlc_info
        del sent_htlc_info  # just renamed
        short_channel_id = shi.route[0].short_channel_id
        chan = self.get_channel_by_short_id(short_channel_id)
        assert chan, ShortChannelID(short_channel_id)
        peer = self._peers.get(shi.route[0].node_id)
        if not peer:
            raise PaymentFailure('Dropped peer')
        await peer.initialized
        htlc = peer.pay(
            route=shi.route,
            chan=chan,
            amount_msat=shi.amount_msat,
            total_msat=shi.bucket_msat,
            payment_hash=paysession.payment_hash,
            min_final_cltv_delta=min_final_cltv_delta,
            payment_secret=shi.payment_secret_bucket,
            trampoline_onion=trampoline_onion)

        key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
        self.sent_htlcs_info[key] = shi
        paysession.add_new_htlc(shi)
        if fw_payment_key:
            htlc_key = serialize_htlc_key(short_channel_id, htlc.htlc_id)
            self.logger.info(f'adding active forwarding {fw_payment_key}')
            self.active_forwardings[fw_payment_key].append(htlc_key)
        if self.network.path_finder:
            # add inflight htlcs to liquidity hints
            self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
        util.trigger_callback('htlc_added', chan, htlc, SENT)

    def handle_error_code_from_failed_htlc(
            self,
            *,
            route: LNPaymentRoute,
            sender_idx: int,
            failure_msg: OnionRoutingFailure,
            amount: int) -> None:

        assert self.channel_db  # cannot be in trampoline mode
        assert self.network.path_finder

        # remove inflight htlcs from liquidity hints
        self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False)

        code, data = failure_msg.code, failure_msg.data
        # TODO can we use lnmsg.OnionWireSerializer here?
        # TODO update onion_wire.csv
        # handle some specific error codes
        failure_codes = {
            OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0,
            OnionFailureCode.AMOUNT_BELOW_MINIMUM: 8,
            OnionFailureCode.FEE_INSUFFICIENT: 8,
            OnionFailureCode.INCORRECT_CLTV_EXPIRY: 4,
            OnionFailureCode.EXPIRY_TOO_SOON: 0,
            OnionFailureCode.CHANNEL_DISABLED: 2,
        }
        try:
            failing_channel = route[sender_idx + 1].short_channel_id
        except IndexError:
            raise PaymentFailure(f'payment destination reported error: {failure_msg.code_name()}') from None

        # TODO: handle unknown next peer?
        # handle failure codes that include a channel update
        if code in failure_codes:
            offset = failure_codes[code]
            channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big")
            channel_update_as_received = data[offset+2: offset+2+channel_update_len]
            payload = self._decode_channel_update_msg(channel_update_as_received)
            if payload is None:
                self.logger.info(f'could not decode channel_update for failed htlc: '
                                 f'{channel_update_as_received.hex()}')
                blacklist = True
            elif payload.get('short_channel_id') != failing_channel:
                self.logger.info(f'short_channel_id in channel_update does not match our route')
                blacklist = True
            else:
                # apply the channel update or get blacklisted
                blacklist, update = self._handle_chanupd_from_failed_htlc(
                    payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg)
                # we interpret a temporary channel failure as a liquidity issue
                # in the channel and update our liquidity hints accordingly
                if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE:
                    self.network.path_finder.update_liquidity_hints(
                        route,
                        amount,
                        failing_channel=ShortChannelID(failing_channel))
                # if we can't decide on some action, we are stuck
                if not (blacklist or update):
                    raise PaymentFailure(failure_msg.code_name())
        # for errors that do not include a channel update
        else:
            blacklist = True
        if blacklist:
            self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel)

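    # The offsets in failure_codes above skip the fixed-size fields that BOLT 4 places in the
    # failure data before the embedded channel_update (e.g. an 8-byte htlc amount for
    # FEE_INSUFFICIENT / AMOUNT_BELOW_MINIMUM, a 4-byte cltv_expiry for INCORRECT_CLTV_EXPIRY,
    # a 2-byte flags field for CHANNEL_DISABLED). Illustrative extraction with made-up bytes:
    #
    #   >>> upd = b'\x01' * 20                                        # stand-in channel_update
    #   >>> data = (500_000).to_bytes(8, "big") + len(upd).to_bytes(2, "big") + upd
    #   >>> offset = 8                                                # FEE_INSUFFICIENT layout
    #   >>> upd_len = int.from_bytes(data[offset:offset + 2], "big")
    #   >>> data[offset + 2: offset + 2 + upd_len] == upd
    #   True
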
    def _handle_chanupd_from_failed_htlc(
        self, payload, *,
        route: LNPaymentRoute,
        sender_idx: int,
        failure_msg: OnionRoutingFailure,
    ) -> Tuple[bool, bool]:
        blacklist = False
        update = False
        try:
            r = self.channel_db.add_channel_update(payload, verify=True)
        except InvalidGossipMsg:
            return True, False  # blacklist
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        if r == UpdateStatus.GOOD:
            self.logger.info(f"applied channel update to {short_channel_id}")
            # TODO: add test for this
            # FIXME: this does not work for our own unannounced channels.
            for chan in self.channels.values():
                if chan.short_channel_id == short_channel_id:
                    chan.set_remote_update(payload)
            update = True
        elif r == UpdateStatus.ORPHANED:
            # maybe it is a private channel (and data in invoice was outdated)
            self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
            start_node_id = route[sender_idx].node_id
            cache_ttl = None
            if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED:
                # eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay
                # a mobile phone with the app closed. So we cache this with a short TTL.
                cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT
            update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl)
            blacklist = not update
        elif r == UpdateStatus.EXPIRED:
            blacklist = True
        elif r == UpdateStatus.DEPRECATED:
            self.logger.info(f'channel update is not more recent.')
            blacklist = True
        elif r == UpdateStatus.UNCHANGED:
            blacklist = True
        return blacklist, update

    @classmethod
    def _decode_channel_update_msg(cls, chan_upd_msg: bytes) -> Optional[Dict[str, Any]]:
        channel_update_as_received = chan_upd_msg
        channel_update_typed = (258).to_bytes(length=2, byteorder="big") + channel_update_as_received
        # note: some nodes put channel updates in error msgs with the leading msg_type already there.
        #       we try decoding both ways here.
        try:
            message_type, payload = decode_msg(channel_update_typed)
            if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
            payload['raw'] = channel_update_typed
            return payload
        except Exception:  # FIXME: too broad
            try:
                message_type, payload = decode_msg(channel_update_as_received)
                if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
                payload['raw'] = channel_update_as_received
                return payload
            except Exception:
                return None

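    # 258 is the wire message type of channel_update in the BOLT 7 gossip protocol, so the
    # first decode attempt above simply prepends that type to the bytes taken from the error:
    #
    #   >>> (258).to_bytes(length=2, byteorder="big")
    #   b'\x01\x02'
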
    def _check_bolt11_invoice(self, bolt11_invoice: str, *, amount_msat: int = None) -> LnAddr:
        """Parses and validates a bolt11 invoice str into a LnAddr.
        Includes pre-payment checks external to the parser.
        """
        addr = lndecode(bolt11_invoice)
        if addr.is_expired():
            raise InvoiceError(_("This invoice has expired"))
        # check amount
        if amount_msat:  # replace amt in invoice. main usecase is paying zero amt invoices
            existing_amt_msat = addr.get_amount_msat()
            if existing_amt_msat and amount_msat < existing_amt_msat:
                raise Exception("cannot pay lower amt than what is originally in LN invoice")
            addr.amount = Decimal(amount_msat) / COIN / 1000
        if addr.amount is None:
            raise InvoiceError(_("Missing amount"))
        # check cltv
        if addr.get_min_final_cltv_delta() > NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
            raise InvoiceError("{}\n{}".format(
                _("Invoice wants us to risk locking funds for unreasonably long."),
                f"min_final_cltv_delta: {addr.get_min_final_cltv_delta()}"))
        # check features
        addr.validate_and_compare_features(self.features)
        return addr

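    # LnAddr.amount appears to be denominated in BTC, which is why the msat override above
    # divides by COIN (100_000_000 sat per BTC) and then by 1000 (msat per sat). Worked
    # example with made-up numbers:
    #
    #   >>> from decimal import Decimal
    #   >>> amount_msat = 123_000
    #   >>> Decimal(amount_msat) / 100_000_000 / 1000       # 123 sat expressed in BTC
    #   Decimal('0.00000123')
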
    def is_trampoline_peer(self, node_id: bytes) -> bool:
        # until trampoline is advertised in lnfeatures, check against hardcoded list
        if is_hardcoded_trampoline(node_id):
            return True
        peer = self._peers.get(node_id)
        if not peer:
            return False
        return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)
                or peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM))

    def suggest_peer(self) -> Optional[bytes]:
        if not self.uses_trampoline():
            return self.lnrater.suggest_peer()
        else:
            return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey

    def suggest_payment_splits(
        self,
        *,
        amount_msat: int,
        final_total_msat: int,
        my_active_channels: Sequence[Channel],
        invoice_features: LnFeatures,
        r_tags: Sequence[Sequence[Sequence[Any]]],
        receiver_pubkey: bytes,
    ) -> List['SplitConfigRating']:
        channels_with_funds = {
            (chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL))
            for chan in my_active_channels
        }
        # if we have a direct channel it's preferable to send a single part directly through this
        # channel, so this bool will disable excluding single part payments
        have_direct_channel = any(chan.node_id == receiver_pubkey for chan in my_active_channels)
        self.logger.info(f"channels_with_funds: {channels_with_funds}, {have_direct_channel=}")
        exclude_single_part_payments = False
        if self.uses_trampoline():
            # in the case of a legacy payment, we don't allow splitting via different
            # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
            is_legacy, _ = is_legacy_relay(invoice_features, r_tags)
            exclude_multinode_payments = is_legacy
            # we don't split within a channel when sending to a trampoline node,
            # the trampoline node will split for us
            exclude_single_channel_splits = True
        else:
            exclude_multinode_payments = False
            exclude_single_channel_splits = False
            if invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and not self.config.TEST_FORCE_DISABLE_MPP:
                # if amt is still large compared to total_msat, split it:
                if (amount_msat / final_total_msat > self.MPP_SPLIT_PART_FRACTION
                        and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT
                        and not have_direct_channel):
                    exclude_single_part_payments = True

        split_configurations = suggest_splits(
            amount_msat,
            channels_with_funds,
            exclude_single_part_payments=exclude_single_part_payments,
            exclude_multinode_payments=exclude_multinode_payments,
            exclude_single_channel_splits=exclude_single_channel_splits,
        )

        self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
        return split_configurations

    async def create_routes_for_payment(
            self, *,
            paysession: PaySession,
            amount_msat: int,        # part of payment amount we want routes for now
            fwd_trampoline_onion: OnionPacket = None,
            full_path: LNPaymentPath = None,
            channels: Optional[Sequence[Channel]] = None,
            budget: PaymentFeeBudget,
    ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:

        """Creates multiple routes for splitting a payment over the available
        private channels.

        We first try to conduct the payment over a single channel. If that fails
        and mpp is supported by the receiver, we will split the payment."""
        trampoline_features = LnFeatures.VAR_ONION_OPT
        local_height = self.wallet.adb.get_local_height()
        fee_related_error = None  # type: Optional[FeeBudgetExceeded]
        if channels:
            my_active_channels = channels
        else:
            my_active_channels = [
                chan for chan in self.channels.values() if
                chan.is_active() and not chan.is_frozen_for_sending()]
        # try random order
        random.shuffle(my_active_channels)
        split_configurations = self.suggest_payment_splits(
            amount_msat=amount_msat,
            final_total_msat=paysession.amount_to_pay,
            my_active_channels=my_active_channels,
            invoice_features=paysession.invoice_features,
            r_tags=paysession.r_tags,
            receiver_pubkey=paysession.invoice_pubkey,
        )
        for sc in split_configurations:
            is_multichan_mpp = len(sc.config.items()) > 1
            is_mpp = sc.config.number_parts() > 1
            if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
                continue
            if not is_mpp and self.config.TEST_FORCE_MPP:
                continue
            if is_mpp and self.config.TEST_FORCE_DISABLE_MPP:
                continue
            self.logger.info(f"trying split configuration: {sc.config.values()} rating: {sc.rating}")
            routes = []
            try:
                if self.uses_trampoline():
                    per_trampoline_channel_amounts = defaultdict(list)
                    # categorize by trampoline nodes for trampoline mpp construction
                    for (chan_id, _), part_amounts_msat in sc.config.items():
                        chan = self.channels[chan_id]
                        for part_amount_msat in part_amounts_msat:
                            per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
                    # for each trampoline forwarder, construct mpp trampoline
                    for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
                        per_trampoline_amount = sum([x[1] for x in trampoline_parts])
                        trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
                            amount_msat=per_trampoline_amount,
                            total_msat=paysession.amount_to_pay,
                            min_final_cltv_delta=paysession.min_final_cltv_delta,
                            my_pubkey=self.node_keypair.pubkey,
                            invoice_pubkey=paysession.invoice_pubkey,
                            invoice_features=paysession.invoice_features,
                            node_id=trampoline_node_id,
                            r_tags=paysession.r_tags,
                            payment_hash=paysession.payment_hash,
                            payment_secret=paysession.payment_secret,
                            local_height=local_height,
                            trampoline_fee_level=paysession.trampoline_fee_level,
                            use_two_trampolines=paysession.use_two_trampolines,
                            failed_routes=paysession.failed_trampoline_routes,
                            budget=budget._replace(fee_msat=budget.fee_msat // len(per_trampoline_channel_amounts)),
                        )
                        # node_features is only used to determine is_tlv
                        per_trampoline_secret = os.urandom(32)
                        per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
                        self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
                        self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
                        self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
                        for chan_id, part_amount_msat in trampoline_parts:
                            chan = self.channels[chan_id]
                            margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat
                            delta_fee = min(per_trampoline_fees, margin)
                            # TODO: distribute trampoline fee over several channels?
                            part_amount_msat_with_fees = part_amount_msat + delta_fee
                            per_trampoline_fees -= delta_fee
                            route = [
                                RouteEdge(
                                    start_node=self.node_keypair.pubkey,
                                    end_node=trampoline_node_id,
                                    short_channel_id=chan.short_channel_id,
                                    fee_base_msat=0,
                                    fee_proportional_millionths=0,
                                    cltv_delta=0,
                                    node_features=trampoline_features)
                            ]
                            self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
                            shi = SentHtlcInfo(
                                route=route,
                                payment_secret_orig=paysession.payment_secret,
                                payment_secret_bucket=per_trampoline_secret,
                                amount_msat=part_amount_msat_with_fees,
                                bucket_msat=per_trampoline_amount_with_fees,
                                amount_receiver_msat=part_amount_msat,
                                trampoline_fee_level=paysession.trampoline_fee_level,
                                trampoline_route=trampoline_route,
                            )
                            routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
                        if per_trampoline_fees != 0:
                            e = 'not enough margin to pay trampoline fee'
                            self.logger.info(e)
                            raise FeeBudgetExceeded(e)
                else:
                    # We atomically loop through a split configuration. If there was
                    # a failure to find a path for a single part, we try the next configuration
                    for (chan_id, _), part_amounts_msat in sc.config.items():
                        for part_amount_msat in part_amounts_msat:
                            channel = self.channels[chan_id]
                            route = await run_in_thread(
                                partial(
                                    self.create_route_for_single_htlc,
                                    amount_msat=part_amount_msat,
                                    invoice_pubkey=paysession.invoice_pubkey,
                                    min_final_cltv_delta=paysession.min_final_cltv_delta,
                                    r_tags=paysession.r_tags,
                                    invoice_features=paysession.invoice_features,
                                    my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
                                    full_path=full_path,
                                    budget=budget._replace(fee_msat=budget.fee_msat // sc.config.number_parts()),
                                )
                            )
                            shi = SentHtlcInfo(
                                route=route,
                                payment_secret_orig=paysession.payment_secret,
                                payment_secret_bucket=paysession.payment_secret,
                                amount_msat=part_amount_msat,
                                bucket_msat=paysession.amount_to_pay,
                                amount_receiver_msat=part_amount_msat,
                                trampoline_fee_level=None,
                                trampoline_route=None,
                            )
                            routes.append((shi, paysession.min_final_cltv_delta, fwd_trampoline_onion))
            except NoPathFound:
                continue
            except FeeBudgetExceeded as e:
                fee_related_error = e
                continue
            for route in routes:
                yield route
            return
        if fee_related_error is not None:
            raise fee_related_error
        raise NoPathFound()

    @profiler
4✔
2122
    def create_route_for_single_htlc(
4✔
2123
            self, *,
2124
            amount_msat: int,  # amount that the final receiver gets
2125
            invoice_pubkey: bytes,
2126
            min_final_cltv_delta: int,
2127
            r_tags,
2128
            invoice_features: int,
2129
            my_sending_channels: List[Channel],
2130
            full_path: Optional[LNPaymentPath],
2131
            budget: PaymentFeeBudget,
2132
    ) -> LNPaymentRoute:
2133

2134
        my_sending_aliases = set(chan.get_local_scid_alias() for chan in my_sending_channels)
4✔
2135
        my_sending_channels = {chan.short_channel_id: chan for chan in my_sending_channels
4✔
2136
            if chan.short_channel_id is not None}
2137
        # Collect all private edges from route hints.
2138
        # Note: if some route hints are multiple edges long, and these paths cross each other,
2139
        #       we allow our path finding to cross the paths; i.e. the route hints are not isolated.
2140
        private_route_edges = {}  # type: Dict[ShortChannelID, RouteEdge]
4✔
2141
        for private_path in r_tags:
4✔
2142
            # we need to shift the node pubkey by one towards the destination:
2143
            private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
4✔
2144
            private_path_rest = [edge[1:] for edge in private_path]
4✔
2145
            start_node = private_path[0][0]
4✔
2146
            # remove aliases from direct routes
2147
            if len(private_path) == 1 and private_path[0][1] in my_sending_aliases:
4✔
2148
                self.logger.info(f'create_route: skipping alias {ShortChannelID(private_path[0][1])}')
×
2149
                continue
×
2150
            for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
4✔
2151
                short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_delta = edge_rest
4✔
2152
                short_channel_id = ShortChannelID(short_channel_id)
4✔
2153
                if (our_chan := self.get_channel_by_short_id(short_channel_id)) is not None:
4✔
2154
                    # check if the channel is one of our channels and frozen for sending
2155
                    if our_chan.is_frozen_for_sending():
4✔
2156
                        continue
×
2157
                # if we have a routing policy for this edge in the db, that takes precedence,
2158
                # as it is likely from a previous failure
2159
                channel_policy = self.channel_db.get_policy_for_node(
4✔
2160
                    short_channel_id=short_channel_id,
2161
                    node_id=start_node,
2162
                    my_channels=my_sending_channels)
2163
                if channel_policy:
4✔
2164
                    fee_base_msat = channel_policy.fee_base_msat
4✔
2165
                    fee_proportional_millionths = channel_policy.fee_proportional_millionths
4✔
2166
                    cltv_delta = channel_policy.cltv_delta
4✔
2167
                node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
4✔
2168
                route_edge = RouteEdge(
4✔
2169
                        start_node=start_node,
2170
                        end_node=end_node,
2171
                        short_channel_id=short_channel_id,
2172
                        fee_base_msat=fee_base_msat,
2173
                        fee_proportional_millionths=fee_proportional_millionths,
2174
                        cltv_delta=cltv_delta,
2175
                        node_features=node_info.features if node_info else 0)
2176
                private_route_edges[route_edge.short_channel_id] = route_edge
4✔
2177
                start_node = end_node
4✔
2178
        # now find a route, end to end: between us and the recipient
2179
        try:
4✔
2180
            route = self.network.path_finder.find_route(
4✔
2181
                nodeA=self.node_keypair.pubkey,
2182
                nodeB=invoice_pubkey,
2183
                invoice_amount_msat=amount_msat,
2184
                path=full_path,
2185
                my_sending_channels=my_sending_channels,
2186
                private_route_edges=private_route_edges)
2187
        except NoChannelPolicy as e:
4✔
2188
            raise NoPathFound() from e
×
2189
        if not route:
4✔
2190
            raise NoPathFound()
4✔
2191
        if not is_route_within_budget(
4✔
2192
            route, budget=budget, amount_msat_for_dest=amount_msat, cltv_delta_for_dest=min_final_cltv_delta,
2193
        ):
2194
            self.logger.info(f"rejecting route (exceeds budget): {route=}. {budget=}")
×
2195
            raise FeeBudgetExceeded()
×
2196
        assert len(route) > 0
4✔
2197
        if route[-1].end_node != invoice_pubkey:
4✔
2198
            raise LNPathInconsistent("last node_id != invoice pubkey")
4✔
2199
        # add features from invoice
2200
        route[-1].node_features |= invoice_features
4✔
2201
        return route
4✔
2202
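# Illustrative sketch (not part of lnworker.py, names are hypothetical): how the
# r_tags loop above turns one BOLT-11 route hint into per-hop edges. Each hop's
# node id is shifted one step towards the destination, so the final edge of the
# hint ends at the invoice pubkey.
def _sketch_hint_to_edges(private_path, invoice_pubkey):
    """private_path: [(node_id, scid, fee_base_msat, fee_ppm, cltv_delta), ...]"""
    end_nodes = [hop[0] for hop in private_path][1:] + [invoice_pubkey]
    edges = []
    start_node = private_path[0][0]
    for (node_id, scid, fee_base_msat, fee_ppm, cltv_delta), end_node in zip(private_path, end_nodes):
        edges.append((start_node, end_node, scid, fee_base_msat, fee_ppm, cltv_delta))
        start_node = end_node
    return edges

# e.g. a two-hop hint A -> B -> receiver yields (A, B, scid_AB) and (B, receiver, scid_B_recv).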

2203
    def clear_invoices_cache(self):
4✔
2204
        self._bolt11_cache.clear()
×
2205

2206
    def get_bolt11_invoice(
4✔
2207
            self, *,
2208
            payment_hash: bytes,
2209
            amount_msat: Optional[int],
2210
            message: str,
2211
            expiry: int,  # expiration of invoice (in seconds, relative)
2212
            fallback_address: Optional[str],
2213
            channels: Optional[Sequence[Channel]] = None,
2214
            min_final_cltv_expiry_delta: Optional[int] = None,
2215
    ) -> Tuple[LnAddr, str]:
2216
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
×
2217

2218
        pair = self._bolt11_cache.get(payment_hash)
×
2219
        if pair:
×
2220
            lnaddr, invoice = pair
×
2221
            assert lnaddr.get_amount_msat() == amount_msat
×
2222
            return pair
×
2223

2224
        assert amount_msat is None or amount_msat > 0
×
2225
        timestamp = int(time.time())
×
2226
        needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
×
2227
        routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit)
×
2228
        self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}")
×
2229
        invoice_features = self.features.for_invoice()
×
2230
        if not self.uses_trampoline():
×
2231
            invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
×
2232
        if needs_jit:
×
2233
            # JIT only works with single HTLCs; MPP would cause the LSP to open a channel for each HTLC
2234
            invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
×
2235
        payment_secret = self.get_payment_secret(payment_hash)
×
2236
        amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
×
2237
        if expiry == 0:
×
2238
            expiry = LN_EXPIRY_NEVER
×
2239
        if min_final_cltv_expiry_delta is None:
×
2240
            min_final_cltv_expiry_delta = MIN_FINAL_CLTV_DELTA_FOR_INVOICE
×
2241
        lnaddr = LnAddr(
×
2242
            paymenthash=payment_hash,
2243
            amount=amount_btc,
2244
            tags=[
2245
                ('d', message),
2246
                ('c', min_final_cltv_expiry_delta),
2247
                ('x', expiry),
2248
                ('9', invoice_features),
2249
                ('f', fallback_address),
2250
            ] + routing_hints,
2251
            date=timestamp,
2252
            payment_secret=payment_secret)
2253
        invoice = lnencode(lnaddr, self.node_keypair.privkey)
×
2254
        pair = lnaddr, invoice
×
2255
        self._bolt11_cache[payment_hash] = pair
×
2256
        return pair
×
2257
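# A minimal usage sketch (not part of lnworker.py), assuming `lnworker` is an
# initialized LNWallet; create_payment_info (defined further below) registers the
# preimage first. The tags built above map to BOLT-11 fields: 'd' description,
# 'c' min_final_cltv_expiry_delta, 'x' expiry, '9' feature bits, 'f' fallback
# address, 'r' routing hints.
payment_hash = lnworker.create_payment_info(amount_msat=100_000_000)  # 100_000 sat
lnaddr, bolt11 = lnworker.get_bolt11_invoice(
    payment_hash=payment_hash,
    amount_msat=100_000_000,
    message='example invoice',
    expiry=3600,
    fallback_address=None,
)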

2258
    def get_payment_secret(self, payment_hash):
4✔
2259
        return sha256(sha256(self.payment_secret_key) + payment_hash)
4✔
2260

2261
    def _get_payment_key(self, payment_hash: bytes) -> bytes:
4✔
2262
        """Return payment bucket key.
2263
        We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
2264
        as it changes over a trampoline path (in the outer onion), and these paths can overlap.
2265
        """
2266
        payment_secret = self.get_payment_secret(payment_hash)
4✔
2267
        return payment_hash + payment_secret
4✔
2268

2269
    def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
4✔
2270
        payment_preimage = os.urandom(32)
4✔
2271
        payment_hash = sha256(payment_preimage)
4✔
2272
        info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
4✔
2273
        self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
4✔
2274
        self.save_payment_info(info, write_to_disk=False)
4✔
2275
        if write_to_disk:
4✔
2276
            self.wallet.save_db()
×
2277
        return payment_hash
4✔
2278

2279
    def bundle_payments(self, hash_list):
4✔
2280
        payment_keys = [self._get_payment_key(x) for x in hash_list]
4✔
2281
        self.payment_bundles.append(payment_keys)
4✔
2282

2283
    def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
4✔
2284
        for key_list in self.payment_bundles:
4✔
2285
            if payment_key in key_list:
4✔
2286
                return key_list
4✔
2287

2288
    def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
4✔
2289
        if sha256(preimage) != payment_hash:
4✔
2290
            raise Exception("tried to save incorrect preimage for payment_hash")
×
2291
        self._preimages[payment_hash.hex()] = preimage.hex()
4✔
2292
        if write_to_disk:
4✔
2293
            self.wallet.save_db()
4✔
2294

2295
    def get_preimage(self, payment_hash: bytes) -> Optional[bytes]:
4✔
2296
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
4✔
2297
        preimage_hex = self._preimages.get(payment_hash.hex())
4✔
2298
        if preimage_hex is None:
4✔
2299
            return None
4✔
2300
        preimage_bytes = bytes.fromhex(preimage_hex)
4✔
2301
        if sha256(preimage_bytes) != payment_hash:
4✔
UNCOV
2302
            raise Exception("found incorrect preimage for payment_hash")
×
2303
        return preimage_bytes
4✔
2304

2305
    def get_preimage_hex(self, payment_hash: str) -> Optional[str]:
4✔
UNCOV
2306
        preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b""
×
UNCOV
2307
        return preimage_bytes.hex() or None
×
2308

2309
    def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
4✔
2310
        """returns None if payment_hash is a payment we are forwarding"""
2311
        key = payment_hash.hex()
4✔
2312
        with self.lock:
4✔
2313
            if key in self.payment_info:
4✔
2314
                amount_msat, direction, status = self.payment_info[key]
4✔
2315
                return PaymentInfo(payment_hash, amount_msat, direction, status)
4✔
2316

2317
    def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: int):
4✔
UNCOV
2318
        info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
×
UNCOV
2319
        self.save_payment_info(info, write_to_disk=False)
×
2320

2321
    def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
4✔
2322
        self.hold_invoice_callbacks[payment_hash] = cb
4✔
2323

2324
    def unregister_hold_invoice(self, payment_hash: bytes):
4✔
UNCOV
2325
        self.hold_invoice_callbacks.pop(payment_hash)
×
2326

2327
    def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
4✔
2328
        key = info.payment_hash.hex()
4✔
2329
        assert info.status in SAVED_PR_STATUS
4✔
2330
        with self.lock:
4✔
2331
            self.payment_info[key] = info.amount_msat, info.direction, info.status
4✔
2332
        if write_to_disk:
4✔
2333
            self.wallet.save_db()
4✔
2334

2335
    def check_mpp_status(
4✔
2336
            self, *,
2337
            payment_secret: bytes,
2338
            short_channel_id: ShortChannelID,
2339
            htlc: UpdateAddHtlc,
2340
            expected_msat: int,
2341
    ) -> RecvMPPResolution:
2342
        """Returns the status of the incoming htlc set that the given *htlc* belongs to.
2343

2344
        ACCEPTED simply means the mpp set is complete, and we can proceed with further
2345
        checks before fulfilling (or failing) the htlcs.
2346
        In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state
2347
        for quite some time -- not in the "WAITING" state (which would refer to the mpp set
2348
        not yet being complete!).
2349
        """
2350
        payment_hash = htlc.payment_hash
4✔
2351
        payment_key = payment_hash + payment_secret
4✔
2352
        self.update_mpp_with_received_htlc(
4✔
2353
            payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
2354
        mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
4✔
2355
        # if still waiting, calc resolution now:
2356
        if mpp_resolution == RecvMPPResolution.WAITING:
4✔
2357
            bundle = self.get_payment_bundle(payment_key)
4✔
2358
            if bundle:
4✔
2359
                payment_keys = bundle
4✔
2360
            else:
2361
                payment_keys = [payment_key]
4✔
2362
            first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
4✔
2363
            if self.get_payment_status(payment_hash) == PR_PAID:
4✔
UNCOV
2364
                mpp_resolution = RecvMPPResolution.ACCEPTED
×
2365
            elif self.stopping_soon:
4✔
2366
                # try to time out pending HTLCs before shutting down
2367
                mpp_resolution = RecvMPPResolution.EXPIRED
4✔
2368
            elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
4✔
2369
                mpp_resolution = RecvMPPResolution.ACCEPTED
4✔
2370
            elif time.time() - first_timestamp > self.MPP_EXPIRY:
4✔
2371
                mpp_resolution = RecvMPPResolution.EXPIRED
4✔
2372
            # save resolution, if any.
2373
            if mpp_resolution != RecvMPPResolution.WAITING:
4✔
2374
                for pkey in payment_keys:
4✔
2375
                    if pkey.hex() in self.received_mpp_htlcs:
4✔
2376
                        self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
4✔
2377

2378
        return mpp_resolution
4✔
2379
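# Illustrative sketch (not part of lnworker.py) of the resolution logic above for a
# WAITING htlc set, with the checks in the same order. The boolean inputs and the
# expiry value stand in for state queried from the wallet; 120s is only an example.
def _sketch_resolve_mpp(paid: bool, stopping_soon: bool, amount_reached: bool,
                        age_sec: float, mpp_expiry_sec: float = 120) -> str:
    if paid:
        return 'ACCEPTED'   # invoice already settled earlier
    if stopping_soon:
        return 'EXPIRED'    # time out pending htlcs before shutting down
    if amount_reached:
        return 'ACCEPTED'   # the full expected amount arrived across the set
    if age_sec > mpp_expiry_sec:
        return 'EXPIRED'    # sender took too long to complete the set
    return 'WAITING'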

2380
    def update_mpp_with_received_htlc(
4✔
2381
        self,
2382
        *,
2383
        payment_key: bytes,
2384
        scid: ShortChannelID,
2385
        htlc: UpdateAddHtlc,
2386
        expected_msat: int,
2387
    ):
2388
        # add new htlc to set
2389
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
4✔
2390
        if mpp_status is None:
4✔
2391
            mpp_status = ReceivedMPPStatus(
4✔
2392
                resolution=RecvMPPResolution.WAITING,
2393
                expected_msat=expected_msat,
2394
                htlc_set=set(),
2395
            )
2396
        if expected_msat != mpp_status.expected_msat:
4✔
2397
            self.logger.info(
4✔
2398
                f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
2399
            mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
4✔
2400
        key = (scid, htlc)
4✔
2401
        if key not in mpp_status.htlc_set:
4✔
2402
            mpp_status.htlc_set.add(key)  # side-effecting htlc_set
4✔
2403
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status
4✔
2404

2405
    def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
4✔
2406
        mpp_status = self.received_mpp_htlcs[payment_key.hex()]
4✔
2407
        self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
4✔
2408
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
4✔
2409

2410
    def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
4✔
2411
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
4✔
2412
        if not mpp_status:
4✔
2413
            return False
4✔
2414
        total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
4✔
2415
        return total >= mpp_status.expected_msat
4✔
2416

2417
    def is_accepted_mpp(self, payment_hash: bytes) -> bool:
4✔
UNCOV
2418
        payment_key = self._get_payment_key(payment_hash)
×
UNCOV
2419
        status = self.received_mpp_htlcs.get(payment_key.hex())
×
UNCOV
2420
        return status and status.resolution == RecvMPPResolution.ACCEPTED
×
2421

2422
    def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
4✔
2423
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
4✔
2424
        if not mpp_status:
4✔
2425
            return int(time.time())
4✔
2426
        return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
4✔
2427

2428
    def maybe_cleanup_mpp(
4✔
2429
            self,
2430
            short_channel_id: ShortChannelID,
2431
            htlc: UpdateAddHtlc,
2432
    ) -> None:
2433

2434
        htlc_key = (short_channel_id, htlc)
4✔
2435
        for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
4✔
2436
            if htlc_key not in mpp_status.htlc_set:
4✔
2437
                continue
4✔
2438
            assert mpp_status.resolution != RecvMPPResolution.WAITING
4✔
2439
            self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
4✔
2440
            mpp_status.htlc_set.remove(htlc_key)  # side-effecting htlc_set
4✔
2441
            if len(mpp_status.htlc_set) == 0:
4✔
2442
                self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
4✔
2443
                self.received_mpp_htlcs.pop(payment_key_hex)
4✔
2444
                self.maybe_cleanup_forwarding(payment_key_hex)
4✔
2445

2446
    def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
4✔
2447
        self.active_forwardings.pop(payment_key_hex, None)
4✔
2448
        self.forwarding_failures.pop(payment_key_hex, None)
4✔
2449

2450
    def get_payment_status(self, payment_hash: bytes) -> int:
4✔
2451
        info = self.get_payment_info(payment_hash)
4✔
2452
        return info.status if info else PR_UNPAID
4✔
2453

2454
    def get_invoice_status(self, invoice: BaseInvoice) -> int:
4✔
2455
        invoice_id = invoice.rhash
4✔
2456
        status = self.get_payment_status(bfh(invoice_id))
4✔
2457
        if status == PR_UNPAID and invoice_id in self.inflight_payments:
4✔
UNCOV
2458
            return PR_INFLIGHT
×
2459
        # status may be PR_FAILED
2460
        if status == PR_UNPAID and invoice_id in self.logs:
4✔
2461
            status = PR_FAILED
×
2462
        return status
4✔
2463

2464
    def set_invoice_status(self, key: str, status: int) -> None:
4✔
2465
        if status == PR_INFLIGHT:
4✔
2466
            self.inflight_payments.add(key)
4✔
2467
        elif key in self.inflight_payments:
4✔
2468
            self.inflight_payments.remove(key)
4✔
2469
        if status in SAVED_PR_STATUS:
4✔
2470
            self.set_payment_status(bfh(key), status)
4✔
2471
        util.trigger_callback('invoice_status', self.wallet, key, status)
4✔
2472
        self.logger.info(f"set_invoice_status {key}: {status}")
4✔
2473
        # liquidity changed
2474
        self.clear_invoices_cache()
4✔
2475

2476
    def set_request_status(self, payment_hash: bytes, status: int) -> None:
4✔
2477
        if self.get_payment_status(payment_hash) == status:
4✔
2478
            return
4✔
2479
        self.set_payment_status(payment_hash, status)
4✔
2480
        request_id = payment_hash.hex()
4✔
2481
        req = self.wallet.get_request(request_id)
4✔
2482
        if req is None:
4✔
2483
            return
4✔
2484
        util.trigger_callback('request_status', self.wallet, request_id, status)
4✔
2485

2486
    def set_payment_status(self, payment_hash: bytes, status: int) -> None:
4✔
2487
        info = self.get_payment_info(payment_hash)
4✔
2488
        if info is None:
4✔
2489
            # if we are forwarding
2490
            return
4✔
2491
        info = info._replace(status=status)
4✔
2492
        self.save_payment_info(info)
4✔
2493

2494
    def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
4✔
2495
        """Returns the payment key if this was a forwarded HTLC, else None."""
2496
        for payment_key, htlcs in self.active_forwardings.items():
4✔
2497
            if htlc_key in htlcs:
4✔
2498
                return payment_key
4✔
2499

2500
    def notify_upstream_peer(self, htlc_key: str) -> None:
4✔
2501
        """Called when an HTLC we offered on a channel gets irrevocably fulfilled or failed.
2502
        If we find this was a forwarded HTLC, the upstream peer is notified.
2503
        """
2504
        upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None)
4✔
2505
        if not upstream_key:
4✔
2506
            return
4✔
2507
        upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
4✔
2508
        upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
4✔
2509
        upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
4✔
2510
        if upstream_peer:
4✔
2511
            upstream_peer.downstream_htlc_resolved_event.set()
4✔
2512
            upstream_peer.downstream_htlc_resolved_event.clear()
4✔
2513

2514
    def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
4✔
2515

2516
        util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
4✔
2517
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
4✔
2518
        fw_key = self.is_forwarded_htlc(htlc_key)
4✔
2519
        if fw_key:
4✔
2520
            fw_htlcs = self.active_forwardings[fw_key]
4✔
2521
            fw_htlcs.remove(htlc_key)
4✔
2522

2523
        shi = self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id))
4✔
2524
        if shi and htlc_id in chan.onion_keys:
4✔
2525
            chan.pop_onion_key(htlc_id)
4✔
2526
            payment_key = payment_hash + shi.payment_secret_orig
4✔
2527
            paysession = self._paysessions[payment_key]
4✔
2528
            q = paysession.sent_htlcs_q
4✔
2529
            htlc_log = HtlcLog(
4✔
2530
                success=True,
2531
                route=shi.route,
2532
                amount_msat=shi.amount_receiver_msat,
2533
                trampoline_fee_level=shi.trampoline_fee_level)
2534
            q.put_nowait(htlc_log)
4✔
2535
            if paysession.can_be_deleted():
4✔
2536
                self._paysessions.pop(payment_key)
4✔
2537
                paysession_active = False
4✔
2538
            else:
2539
                paysession_active = True
4✔
2540
        else:
2541
            if fw_key:
4✔
2542
                paysession_active = False
4✔
2543
            else:
2544
                key = payment_hash.hex()
4✔
2545
                self.set_invoice_status(key, PR_PAID)
4✔
2546
                util.trigger_callback('payment_succeeded', self.wallet, key)
4✔
2547

2548
        if fw_key:
4✔
2549
            fw_htlcs = self.active_forwardings[fw_key]
4✔
2550
            if len(fw_htlcs) == 0 and not paysession_active:
4✔
2551
                self.notify_upstream_peer(htlc_key)
4✔
2552

2553
    def htlc_failed(
4✔
2554
            self,
2555
            chan: Channel,
2556
            payment_hash: bytes,
2557
            htlc_id: int,
2558
            error_bytes: Optional[bytes],
2559
            failure_message: Optional['OnionRoutingFailure']):
2560
        # note: this may be called several times for the same htlc
2561

2562
        util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
4✔
2563
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
4✔
2564
        fw_key = self.is_forwarded_htlc(htlc_key)
4✔
2565
        if fw_key:
4✔
2566
            fw_htlcs = self.active_forwardings[fw_key]
4✔
2567
            fw_htlcs.remove(htlc_key)
4✔
2568

2569
        shi = self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id))
4✔
2570
        if shi and htlc_id in chan.onion_keys:
4✔
2571
            onion_key = chan.pop_onion_key(htlc_id)
4✔
2572
            payment_okey = payment_hash + shi.payment_secret_orig
4✔
2573
            paysession = self._paysessions[payment_okey]
4✔
2574
            q = paysession.sent_htlcs_q
4✔
2575
            # detect if it is part of a bucket
2576
            # if yes, wait until the bucket completely failed
2577
            route = shi.route
4✔
2578
            if error_bytes:
4✔
2579
                # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
2580
                try:
4✔
2581
                    failure_message, sender_idx = decode_onion_error(
4✔
2582
                        error_bytes,
2583
                        [x.node_id for x in route],
2584
                        onion_key)
2585
                except Exception as e:
4✔
2586
                    sender_idx = None
4✔
2587
                    failure_message = OnionRoutingFailure(OnionFailureCode.INVALID_ONION_PAYLOAD, str(e).encode())
4✔
2588
            else:
2589
                # probably got "update_fail_malformed_htlc". well... who to penalise now?
UNCOV
2590
                assert failure_message is not None
×
UNCOV
2591
                sender_idx = None
×
2592
            self.logger.info(f"htlc_failed {failure_message}")
4✔
2593
            amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
4✔
2594
            if amount_receiver_msat is None:
4✔
2595
                return
4✔
2596
            if shi.trampoline_route:
4✔
2597
                route = shi.trampoline_route
4✔
2598
            htlc_log = HtlcLog(
4✔
2599
                success=False,
2600
                route=route,
2601
                amount_msat=amount_receiver_msat,
2602
                error_bytes=error_bytes,
2603
                failure_msg=failure_message,
2604
                sender_idx=sender_idx,
2605
                trampoline_fee_level=shi.trampoline_fee_level)
2606
            q.put_nowait(htlc_log)
4✔
2607
            if paysession.can_be_deleted():
4✔
2608
                self._paysessions.pop(payment_okey)
4✔
2609
                paysession_active = False
4✔
2610
            else:
2611
                paysession_active = True
4✔
2612
        else:
2613
            if fw_key:
4✔
2614
                paysession_active = False
4✔
2615
            else:
2616
                self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
4✔
2617
                key = payment_hash.hex()
4✔
2618
                if self.get_payment_status(payment_hash) != PR_UNPAID:
4✔
UNCOV
2619
                    self.set_invoice_status(key, PR_UNPAID)
×
UNCOV
2620
                    util.trigger_callback('payment_failed', self.wallet, key, '')
×
2621

2622
        if fw_key:
4✔
2623
            fw_htlcs = self.active_forwardings[fw_key]
4✔
2624
            can_forward_failure = (len(fw_htlcs) == 0) and not paysession_active
4✔
2625
            if can_forward_failure:
4✔
2626
                self.logger.info(f'htlc_failed: save_forwarding_failure (phash={payment_hash.hex()})')
4✔
2627
                self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
4✔
2628
                self.notify_upstream_peer(htlc_key)
4✔
2629
            else:
2630
                self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})')
4✔
2631

2632
    def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False):
4✔
2633
        """calculate routing hints (BOLT-11 'r' field)"""
2634
        routing_hints = []
4✔
2635
        if needs_jit:
4✔
UNCOV
2636
            node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
×
UNCOV
2637
            alias_or_scid = self.get_static_jit_scid_alias()
×
UNCOV
2638
            routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
×
2639
            # no need for more hints: we cannot receive enough through the other channels, and MPP is disabled for JIT
2640
            channels = []
×
2641
        else:
2642
            if channels is None:
4✔
2643
                channels = list(self.get_channels_for_receiving(amount_msat=amount_msat, include_disconnected=True))
4✔
2644
                random.shuffle(channels)  # let's not leak channel order
4✔
2645
            scid_to_my_channels = {
4✔
2646
                chan.short_channel_id: chan for chan in channels
2647
                if chan.short_channel_id is not None
2648
            }
2649
        for chan in channels:
4✔
2650
            alias_or_scid = chan.get_remote_scid_alias() or chan.short_channel_id
4✔
2651
            assert isinstance(alias_or_scid, bytes), alias_or_scid
4✔
2652
            channel_info = get_mychannel_info(chan.short_channel_id, scid_to_my_channels)
4✔
2653
            # note: as a fallback, if we don't have a channel update for the
2654
            # incoming direction of our private channel, we fill the invoice with garbage.
2655
            # the sender should still be able to pay us, but will incur an extra round trip
2656
            # (they will get the channel update from the onion error)
2657
            # at least, that's the theory. https://github.com/lightningnetwork/lnd/issues/2066
2658
            fee_base_msat = fee_proportional_millionths = 0
4✔
2659
            cltv_delta = 1  # lnd won't even try with zero
4✔
2660
            missing_info = True
4✔
2661
            if channel_info:
4✔
2662
                policy = get_mychannel_policy(channel_info.short_channel_id, chan.node_id, scid_to_my_channels)
4✔
2663
                if policy:
4✔
2664
                    fee_base_msat = policy.fee_base_msat
4✔
2665
                    fee_proportional_millionths = policy.fee_proportional_millionths
4✔
2666
                    cltv_delta = policy.cltv_delta
4✔
2667
                    missing_info = False
4✔
2668
            if missing_info:
4✔
UNCOV
2669
                self.logger.info(
×
2670
                    f"Warning. Missing channel update for our channel {chan.short_channel_id}; "
2671
                    f"filling invoice with incorrect data.")
2672
            routing_hints.append(('r', [(
4✔
2673
                chan.node_id,
2674
                alias_or_scid,
2675
                fee_base_msat,
2676
                fee_proportional_millionths,
2677
                cltv_delta)]))
2678
        return routing_hints
4✔
2679
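# Illustrative sketch (not part of lnworker.py) of the shape of a single hint
# appended above; all values are hypothetical. alias_or_scid is either the remote
# scid alias or the real short channel id, and the fee/cltv fields come from our
# own channel policy (or placeholders if we have no channel_update).
example_hint = ('r', [(
    bytes.fromhex('02' * 33),   # remote node_id (33-byte compressed pubkey)
    bytes(8),                   # alias_or_scid (8 bytes)
    1000,                       # fee_base_msat
    100,                        # fee_proportional_millionths
    144,                        # cltv_delta
)])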

2680
    def delete_payment_info(self, payment_hash_hex: str):
4✔
2681
        # This method is called when an invoice or request is deleted by the user.
2682
        # The GUI only lets the user delete invoices or requests that have not been paid.
2683
        # Once an invoice/request has been paid, it is part of the history,
2684
        # and get_lightning_history assumes that payment_info is there.
UNCOV
2685
        assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
×
UNCOV
2686
        with self.lock:
×
UNCOV
2687
            self.payment_info.pop(payment_hash_hex, None)
×
2688

2689
    def get_balance(self, frozen=False):
4✔
2690
        with self.lock:
×
UNCOV
2691
            return Decimal(sum(
×
2692
                chan.balance(LOCAL) if not chan.is_closed() and (chan.is_frozen_for_sending() if frozen else True) else 0
2693
                for chan in self.channels.values())) / 1000
2694

2695
    def get_channels_for_sending(self):
4✔
UNCOV
2696
        for c in self.channels.values():
×
UNCOV
2697
            if c.is_active() and not c.is_frozen_for_sending():
×
UNCOV
2698
                if self.channel_db or self.is_trampoline_peer(c.node_id):
×
2699
                    yield c
×
2700

2701
    def fee_estimate(self, amount_sat):
4✔
2702
        # Here we have to guess a fee, because some callers (submarine swaps)
2703
        # use this method to initiate a payment, which would otherwise fail.
UNCOV
2704
        fee_base_msat = 5000               # FIXME ehh.. there ought to be a better way...
×
UNCOV
2705
        fee_proportional_millionths = 500  # FIXME
×
2706
        # inverse of fee_for_edge_msat
2707
        amount_msat = amount_sat * 1000
×
2708
        amount_minus_fees = (amount_msat - fee_base_msat) * 1_000_000 // (1_000_000 + fee_proportional_millionths)
×
UNCOV
2709
        return Decimal(amount_msat - amount_minus_fees) / 1000
×
2710
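# A worked example (not part of lnworker.py) of the inverse fee calculation above,
# with the hard-coded guesses fee_base_msat=5000 and fee_proportional_millionths=500.
amount_sat = 100_000
amount_msat = amount_sat * 1000                                        # 100_000_000
amount_minus_fees = (amount_msat - 5_000) * 1_000_000 // (1_000_000 + 500)
fee_msat = amount_msat - amount_minus_fees
# amount_minus_fees == 99_945_027, so the estimated fee is 54_973 msat (~55 sat),
# i.e. roughly the 5000 msat base plus 0.05% of the forwarded amount.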

2711
    def num_sats_can_send(self, deltas=None) -> Decimal:
4✔
2712
        """
2713
        Without trampoline, this is the sum of spendable capacity over all sending channels;
2714
        with trampoline, all MPP parts must go through a single trampoline node.
2715
        """
UNCOV
2716
        if deltas is None:
×
UNCOV
2717
            deltas = {}
×
2718

2719
        def send_capacity(chan):
×
2720
            if chan in deltas:
×
UNCOV
2721
                delta_msat = deltas[chan] * 1000
×
2722
                if delta_msat > chan.available_to_spend(REMOTE):
×
2723
                    delta_msat = 0
×
2724
            else:
2725
                delta_msat = 0
×
2726
            return chan.available_to_spend(LOCAL) + delta_msat
×
UNCOV
2727
        can_send_dict = defaultdict(int)
×
2728
        with self.lock:
×
2729
            for c in self.get_channels_for_sending():
×
2730
                if not self.uses_trampoline():
×
2731
                    can_send_dict[0] += send_capacity(c)
×
2732
                else:
2733
                    can_send_dict[c.node_id] += send_capacity(c)
×
2734
        can_send = max(can_send_dict.values()) if can_send_dict else 0
×
UNCOV
2735
        can_send_sat = Decimal(can_send)/1000
×
2736
        can_send_sat -= self.fee_estimate(can_send_sat)
×
2737
        return max(can_send_sat, 0)
×
2738
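# Illustrative sketch (not part of lnworker.py) of the aggregation above: without
# trampoline all sending capacity is pooled, with trampoline capacities are grouped
# per trampoline node and only the best single group counts.
from collections import defaultdict

def _sketch_can_send_msat(chans, uses_trampoline: bool) -> int:
    """chans: iterable of (trampoline_node_id, spendable_msat) pairs."""
    buckets = defaultdict(int)
    for node_id, spendable_msat in chans:
        buckets[node_id if uses_trampoline else 0] += spendable_msat
    return max(buckets.values()) if buckets else 0

# e.g. two channels of 1_000_000 msat to different trampolines give 2_000_000 msat
# without trampoline routing, but only 1_000_000 msat with it.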

2739
    def get_channels_for_receiving(
4✔
2740
        self, *, amount_msat: Optional[int] = None, include_disconnected: bool = False,
2741
    ) -> Sequence[Channel]:
2742
        if not amount_msat:  # assume we want to recv a large amt, e.g. finding max.
4✔
UNCOV
2743
            amount_msat = float('inf')
×
2744
        with self.lock:
4✔
2745
            channels = list(self.channels.values())
4✔
2746
            channels = [chan for chan in channels
4✔
2747
                        if chan.is_open() and not chan.is_frozen_for_receiving()]
2748

2749
            if not include_disconnected:
4✔
UNCOV
2750
                channels = [chan for chan in channels if chan.is_active()]
×
2751

2752
            # Filter out nodes that have low receive capacity compared to invoice amt.
2753
            # Even with MPP, below a certain threshold, including these channels probably
2754
            # hurts more than help, as they lead to many failed attempts for the sender.
2755
            channels = sorted(channels, key=lambda chan: -chan.available_to_spend(REMOTE))
4✔
2756
            selected_channels = []
4✔
2757
            running_sum = 0
4✔
2758
            cutoff_factor = 0.2  # heuristic
4✔
2759
            for chan in channels:
4✔
2760
                recv_capacity = chan.available_to_spend(REMOTE)
4✔
2761
                chan_can_handle_payment_as_single_part = recv_capacity >= amount_msat
4✔
2762
                chan_small_compared_to_running_sum = recv_capacity < cutoff_factor * running_sum
4✔
2763
                if not chan_can_handle_payment_as_single_part and chan_small_compared_to_running_sum:
4✔
2764
                    break
4✔
2765
                running_sum += recv_capacity
4✔
2766
                selected_channels.append(chan)
4✔
2767
            channels = selected_channels
4✔
2768
            del selected_channels
4✔
2769
            # cap max channels to include to keep QR code reasonably scannable
2770
            channels = channels[:10]
4✔
2771
            return channels
4✔
2772
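# Illustrative sketch (not part of lnworker.py) of the selection heuristic above:
# channels are visited in order of decreasing receive capacity, and a channel is
# dropped once it can neither carry the full amount on its own nor add at least
# cutoff_factor (20%) of the capacity already selected.
def _sketch_select_recv_capacities(capacities_msat, amount_msat, cutoff_factor=0.2):
    selected, running_sum = [], 0
    for cap in sorted(capacities_msat, reverse=True):
        if cap < amount_msat and cap < cutoff_factor * running_sum:
            break
        running_sum += cap
        selected.append(cap)
    return selected[:10]  # keep the invoice QR code reasonably scannable

# e.g. capacities [5_000_000, 1_500_000, 200_000] msat with a 6_000_000 msat invoice
# keep the first two; the 200_000 msat channel is dropped (200_000 < 0.2 * 6_500_000).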

2773
    def num_sats_can_receive(self, deltas=None) -> Decimal:
4✔
2774
        """
2775
        We no longer assume that the sender will split an MPP payment across our channels,
2776
        because channel liquidities are hard to guess; we return the largest single-channel capacity.
2777
        """
UNCOV
2778
        if deltas is None:
×
UNCOV
2779
            deltas = {}
×
2780

2781
        def recv_capacity(chan):
×
2782
            if chan in deltas:
×
UNCOV
2783
                delta_msat = deltas[chan] * 1000
×
2784
                if delta_msat > chan.available_to_spend(LOCAL):
×
2785
                    delta_msat = 0
×
2786
            else:
2787
                delta_msat = 0
×
2788
            return chan.available_to_spend(REMOTE) + delta_msat
×
UNCOV
2789
        with self.lock:
×
2790
            recv_channels = self.get_channels_for_receiving()
×
2791
            recv_chan_msats = [recv_capacity(chan) for chan in recv_channels]
×
2792
        if not recv_chan_msats:
×
2793
            return Decimal(0)
×
2794
        can_receive_msat = max(recv_chan_msats)
×
2795
        return Decimal(can_receive_msat) / 1000
×
2796

2797
    def receive_requires_jit_channel(self, amount_msat: Optional[int]) -> bool:
4✔
2798
        """Returns True if we cannot receive the amount and a trusted LSP node is configured.
2799
        This cannot work reliably for 0-amount invoices, as we don't know in advance whether we can receive the payment.
2800
        """
2801
        # zeroconf provider is configured and connected
UNCOV
2802
        if (self.can_get_zeroconf_channel()
×
2803
                # we cannot receive the amount specified
2804
                and ((amount_msat and self.num_sats_can_receive() < (amount_msat // 1000))
2805
                    # or we cannot receive anything, and it's a 0 amount invoice
2806
                    or (not amount_msat and self.num_sats_can_receive() < 1))):
UNCOV
2807
            return True
×
UNCOV
2808
        return False
×
2809

2810
    def can_get_zeroconf_channel(self) -> bool:
4✔
2811
        if not (self.config.ACCEPT_ZEROCONF_CHANNELS and self.config.ZEROCONF_TRUSTED_NODE):
×
2812
            # check if zeroconf is accepted and client has trusted zeroconf node configured
UNCOV
2813
            return False
×
2814
        try:
×
UNCOV
2815
            node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0]
×
2816
        except ConnStringFormatError:
×
2817
            # invalid connection string
2818
            return False
×
2819
        # only return True if we are connected to the zeroconf provider
UNCOV
2820
        return node_id in self.peers
×
2821

2822
    def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
4✔
2823
        """
2824
        Suggest channels and, for each, the amount to send/receive through it, so that we will be able to receive/send amount_sat.
2825
        This is used when suggesting a swap or a rebalance in order to receive a payment.
2826
        """
UNCOV
2827
        with self.lock:
×
UNCOV
2828
            func = self.num_sats_can_send if direction == SENT else self.num_sats_can_receive
×
UNCOV
2829
            suggestions = []
×
2830
            channels = self.get_channels_for_sending() if direction == SENT else self.get_channels_for_receiving()
×
2831
            for chan in channels:
×
2832
                available_sat = chan.available_to_spend(LOCAL if direction == SENT else REMOTE) // 1000
×
2833
                delta = amount_sat - available_sat
×
2834
                delta += self.fee_estimate(amount_sat)
×
2835
                # add safety margin
2836
                delta += delta // 100 + 1
×
2837
                if func(deltas={chan:delta}) >= amount_sat:
×
UNCOV
2838
                    suggestions.append((chan, delta))
×
2839
                elif direction==RECEIVED and func(deltas={chan:2*delta}) >= amount_sat:
×
2840
                    # the MPP receive heuristic has a ~0.5 slope
2841
                    suggestions.append((chan, 2*delta))
×
2842
        if not suggestions:
×
UNCOV
2843
            raise NotEnoughFunds
×
2844
        return suggestions
×
2845
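# A worked example (not part of lnworker.py) of the delta computed above for the
# receive direction, using hypothetical numbers and leaving out the lightning fee
# estimate for brevity.
amount_sat = 500_000
available_sat = 300_000                  # what the channel can already receive
delta = amount_sat - available_sat       # 200_000 sat missing
delta += delta // 100 + 1                # ~1% safety margin -> 202_001 sat
# if receiving `delta` through the channel is still not enough, the code retries
# with 2*delta, because the MPP receive heuristic has a ~0.5 slope.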

2846
    def _suggest_rebalance(self, direction, amount_sat):
4✔
2847
        """
2848
        Suggest a rebalance in order to be able to send or receive amount_sat.
2849
        Returns (from_channel, to_channel, amount to shuffle)
2850
        """
UNCOV
2851
        try:
×
UNCOV
2852
            suggestions = self._suggest_channels_for_rebalance(direction, amount_sat)
×
UNCOV
2853
        except NotEnoughFunds:
×
2854
            return False
×
2855
        for chan2, delta in suggestions:
×
2856
            # margin for fee caused by rebalancing
2857
            delta += self.fee_estimate(amount_sat)
×
2858
            # find other channel or trampoline that can send delta
UNCOV
2859
            for chan1 in self.channels.values():
×
2860
                if chan1.is_frozen_for_sending() or not chan1.is_active():
×
UNCOV
2861
                    continue
×
2862
                if chan1 == chan2:
×
2863
                    continue
×
2864
                if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2865
                    continue
×
2866
                if direction == SENT:
×
2867
                    if chan1.can_pay(delta*1000):
×
2868
                        return (chan1, chan2, delta)
×
2869
                else:
2870
                    if chan1.can_receive(delta*1000):
×
2871
                        return (chan2, chan1, delta)
×
2872
            else:
2873
                continue
×
2874
        else:
UNCOV
2875
            return False
×
2876

2877
    def num_sats_can_rebalance(self, chan1, chan2):
4✔
2878
        # TODO: we should be able to spend 'max', with variable fee
UNCOV
2879
        n1 = chan1.available_to_spend(LOCAL)
×
UNCOV
2880
        n1 -= self.fee_estimate(n1)
×
UNCOV
2881
        n2 = chan2.available_to_spend(REMOTE)
×
2882
        amount_sat = min(n1, n2) // 1000
×
2883
        return amount_sat
×
2884

2885
    def suggest_rebalance_to_send(self, amount_sat):
4✔
2886
        return self._suggest_rebalance(SENT, amount_sat)
×
2887

2888
    def suggest_rebalance_to_receive(self, amount_sat):
4✔
2889
        return self._suggest_rebalance(RECEIVED, amount_sat)
×
2890

2891
    def suggest_swap_to_send(self, amount_sat, coins):
4✔
2892
        # fixme: if swap_amount_sat is lower than the minimum swap amount, we need to propose a higher value
UNCOV
2893
        assert amount_sat > self.num_sats_can_send()
×
UNCOV
2894
        try:
×
UNCOV
2895
            suggestions = self._suggest_channels_for_rebalance(SENT, amount_sat)
×
2896
        except NotEnoughFunds:
×
2897
            return None
×
2898
        for chan, swap_recv_amount in suggestions:
×
2899
            # check that we can send onchain
2900
            swap_server_mining_fee = 10000 # guessing, because we have not called get_pairs yet
×
2901
            swap_funding_sat = swap_recv_amount + swap_server_mining_fee
×
UNCOV
2902
            swap_output = PartialTxOutput.from_address_and_value(DummyAddress.SWAP, int(swap_funding_sat))
×
2903
            try:
×
2904
                # check if we have enough onchain funds
2905
                self.wallet.make_unsigned_transaction(
×
2906
                    coins=coins,
2907
                    outputs=[swap_output],
2908
                    fee_policy=FeePolicy(self.config.FEE_POLICY_SWAPS),
2909
                )
UNCOV
2910
            except NotEnoughFunds:
×
UNCOV
2911
                continue
×
UNCOV
2912
            return chan, swap_recv_amount
×
2913
        return None
×
2914

2915
    def suggest_swap_to_receive(self, amount_sat):
4✔
2916
        assert amount_sat > self.num_sats_can_receive()
×
UNCOV
2917
        try:
×
UNCOV
2918
            suggestions = self._suggest_channels_for_rebalance(RECEIVED, amount_sat)
×
2919
        except NotEnoughFunds:
×
2920
            return
×
2921
        for chan, swap_recv_amount in suggestions:
×
2922
            return chan, swap_recv_amount
×
2923

2924
    async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msat: int):
4✔
2925
        if chan1 == chan2:
×
UNCOV
2926
            raise Exception('Rebalance requires two different channels')
×
UNCOV
2927
        if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2928
            raise Exception('Rebalance requires channels from different trampolines')
×
2929
        payment_hash = self.create_payment_info(amount_msat=amount_msat)
×
2930
        lnaddr, invoice = self.get_bolt11_invoice(
×
2931
            payment_hash=payment_hash,
2932
            amount_msat=amount_msat,
2933
            message='rebalance',
2934
            expiry=3600,
2935
            fallback_address=None,
2936
            channels=[chan2],
2937
        )
UNCOV
2938
        invoice_obj = Invoice.from_bech32(invoice)
×
UNCOV
2939
        return await self.pay_invoice(invoice_obj, channels=[chan1])
×
2940

2941
    def can_receive_invoice(self, invoice: BaseInvoice) -> bool:
4✔
2942
        assert invoice.is_lightning()
×
UNCOV
2943
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_receive()
×
2944

2945
    async def close_channel(self, chan_id):
4✔
2946
        chan = self._channels[chan_id]
×
UNCOV
2947
        peer = self._peers[chan.node_id]
×
UNCOV
2948
        return await peer.close_channel(chan_id)
×
2949

2950
    def _force_close_channel(self, chan_id: bytes) -> Transaction:
4✔
2951
        chan = self._channels[chan_id]
4✔
2952
        tx = chan.force_close_tx()
4✔
2953
        # We set the channel state to make sure we won't sign new commitment txs.
2954
        # We expect the caller to try to broadcast this tx, after which it is
2955
        # not safe to keep using the channel even if the broadcast errors (server could be lying).
2956
        # Until the tx is seen in the mempool, there will be automatic rebroadcasts.
2957
        chan.set_state(ChannelState.FORCE_CLOSING)
4✔
2958
        # Add local tx to wallet to also allow manual rebroadcasts.
2959
        try:
4✔
2960
            self.wallet.adb.add_transaction(tx)
4✔
UNCOV
2961
        except UnrelatedTransactionException:
×
UNCOV
2962
            pass  # this can happen if ~all of the balance goes to REMOTE
×
2963
        return tx
4✔
2964

2965
    async def force_close_channel(self, chan_id: bytes) -> str:
4✔
2966
        """Force-close the channel. Network-related exceptions are propagated to the caller.
2967
        (automatic rebroadcasts will be scheduled)
2968
        """
2969
        # note: as we are async, it can take a few event loop iterations between the caller
2970
        #       "calling us" and us getting to run, and we only set the channel state now:
2971
        tx = self._force_close_channel(chan_id)
4✔
2972
        await self.network.broadcast_transaction(tx)
4✔
2973
        return tx.txid()
4✔
2974

2975
    def schedule_force_closing(self, chan_id: bytes) -> 'asyncio.Task[bool]':
4✔
2976
        """Schedules a task to force-close the channel and returns it.
2977
        Network-related exceptions are suppressed.
2978
        (automatic rebroadcasts will be scheduled)
2979
        Note: this method is intentionally not async so that callers have a guarantee
2980
              that the channel state is set immediately.
2981
        """
2982
        tx = self._force_close_channel(chan_id)
4✔
2983
        return asyncio.create_task(self.network.try_broadcasting(tx, 'force-close'))
4✔
2984

2985
    def remove_channel(self, chan_id):
4✔
UNCOV
2986
        chan = self.channels[chan_id]
×
UNCOV
2987
        assert chan.can_be_deleted()
×
UNCOV
2988
        with self.lock:
×
2989
            self._channels.pop(chan_id)
×
2990
            self.db.get('channels').pop(chan_id.hex())
×
2991
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
×
2992

2993
        util.trigger_callback('channels_updated', self.wallet)
×
2994
        util.trigger_callback('wallet_updated', self.wallet)
×
2995

2996
    @ignore_exceptions
4✔
2997
    @log_exceptions
4✔
2998
    async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
4✔
UNCOV
2999
        now = time.time()
×
UNCOV
3000
        peer_addresses = []
×
UNCOV
3001
        if self.uses_trampoline():
×
3002
            addr = trampolines_by_id().get(chan.node_id)
×
3003
            if addr:
×
3004
                peer_addresses.append(addr)
×
3005
        else:
3006
            # will try last good address first, from gossip
3007
            last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
×
UNCOV
3008
            if last_good_addr:
×
UNCOV
3009
                peer_addresses.append(last_good_addr)
×
3010
            # will try addresses for node_id from gossip
3011
            addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
×
3012
            for host, port, ts in addrs_from_gossip:
×
UNCOV
3013
                peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
×
3014
        # will try addresses stored in channel storage
3015
        peer_addresses += list(chan.get_peer_addresses())
×
3016
        # Done gathering addresses.
3017
        # Now select first one that has not failed recently.
3018
        for peer in peer_addresses:
×
UNCOV
3019
            if self._can_retry_addr(peer, urgent=True, now=now):
×
UNCOV
3020
                await self._add_peer(peer.host, peer.port, peer.pubkey)
×
3021
                return
×
3022

3023
    async def reestablish_peers_and_channels(self):
4✔
3024
        while True:
×
UNCOV
3025
            await asyncio.sleep(1)
×
UNCOV
3026
            if self.stopping_soon:
×
3027
                return
×
3028
            if self.config.ZEROCONF_TRUSTED_NODE:
×
3029
                peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
×
3030
                if self._can_retry_addr(peer, urgent=True):
×
3031
                    await self._add_peer(peer.host, peer.port, peer.pubkey)
×
3032
            for chan in self.channels.values():
×
3033
                # reestablish
3034
                # note: we delegate filtering out uninteresting chans to this:
3035
                if not chan.should_try_to_reestablish_peer():
×
UNCOV
3036
                    continue
×
UNCOV
3037
                peer = self._peers.get(chan.node_id, None)
×
3038
                if peer:
×
3039
                    await peer.taskgroup.spawn(peer.reestablish_channel(chan))
×
3040
                else:
3041
                    await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
×
3042

3043
    def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]:
4✔
3044
        if self.network.fee_estimates.has_data():
4✔
3045
            target: int = FEE_LN_MINIMUM_ETA_TARGET if has_anchors else FEE_LN_ETA_TARGET
4✔
3046
            feerate_per_kvbyte = self.network.fee_estimates.eta_target_to_fee(target)
4✔
3047
            if has_anchors:
4✔
3048
                # set a floor of 5 sat/vb to have some safety margin in case the mempool
3049
                # grows quickly
UNCOV
3050
                feerate_per_kvbyte = max(feerate_per_kvbyte, 5000)
×
3051
        else:
UNCOV
3052
            if constants.net is not constants.BitcoinRegtest:
×
3053
                return None
×
UNCOV
3054
            feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE
×
3055
        return max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
4✔
3056
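# A worked example (not part of lnworker.py) of the unit handling above. Fee
# estimates arrive in sat per kvbyte; dividing by 4 converts to sat per kw
# (1 vbyte = 4 weight units), and the result is floored at the minimum relay
# feerate. The constant below is a hypothetical stand-in.
MIN_RELAY_FEERATE_PER_KW = 253
feerate_per_kvbyte = 10_000                                               # 10 sat/vbyte
feerate_per_kw = max(MIN_RELAY_FEERATE_PER_KW, feerate_per_kvbyte // 4)   # 2_500 sat/kw
# for anchor channels the kvbyte rate is first floored at 5_000 (5 sat/vbyte) above.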

3057
    def current_low_feerate_per_kw_srk_channel(self) -> Optional[int]:
4✔
3058
        """Gets low feerate for static remote key channels."""
3059
        if constants.net is constants.BitcoinRegtest:
4✔
UNCOV
3060
            feerate_per_kvbyte = 0
×
3061
        else:
3062
            if not self.network.fee_estimates.has_data():
4✔
3063
                return None
×
3064
            feerate_per_kvbyte = self.network.fee_estimates.eta_target_to_fee(FEE_LN_LOW_ETA_TARGET) or 0
4✔
3065
        low_feerate_per_kw = max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
4✔
3066
        # make sure this is never higher than the target feerate:
3067
        current_target_feerate = self.current_target_feerate_per_kw(has_anchors=False)
4✔
3068
        if not current_target_feerate:
4✔
UNCOV
3069
            return None
×
3070
        low_feerate_per_kw = min(low_feerate_per_kw, current_target_feerate)
4✔
3071
        return low_feerate_per_kw
4✔
3072

3073
    def create_channel_backup(self, channel_id: bytes):
4✔
UNCOV
3074
        chan = self._channels[channel_id]
×
3075
        # do not backup old-style channels
UNCOV
3076
        assert chan.is_static_remotekey_enabled()
×
3077
        peer_addresses = list(chan.get_peer_addresses())
×
UNCOV
3078
        peer_addr = peer_addresses[0]
×
3079
        return ImportedChannelBackupStorage(
×
3080
            node_id = chan.node_id,
3081
            privkey = self.node_keypair.privkey,
3082
            funding_txid = chan.funding_outpoint.txid,
3083
            funding_index = chan.funding_outpoint.output_index,
3084
            funding_address = chan.get_funding_address(),
3085
            host = peer_addr.host,
3086
            port = peer_addr.port,
3087
            is_initiator = chan.constraints.is_initiator,
3088
            channel_seed = chan.config[LOCAL].channel_seed,
3089
            local_delay = chan.config[LOCAL].to_self_delay,
3090
            remote_delay = chan.config[REMOTE].to_self_delay,
3091
            remote_revocation_pubkey = chan.config[REMOTE].revocation_basepoint.pubkey,
3092
            remote_payment_pubkey = chan.config[REMOTE].payment_basepoint.pubkey,
3093
            local_payment_pubkey=chan.config[LOCAL].payment_basepoint.pubkey,
3094
            multisig_funding_privkey=chan.config[LOCAL].multisig_key.privkey,
3095
        )
3096

3097
    def export_channel_backup(self, channel_id):
4✔
UNCOV
3098
        xpub = self.wallet.get_fingerprint()
×
UNCOV
3099
        backup_bytes = self.create_channel_backup(channel_id).to_bytes()
×
UNCOV
3100
        assert backup_bytes == ImportedChannelBackupStorage.from_bytes(backup_bytes).to_bytes(), "roundtrip failed"
×
3101
        encrypted = pw_encode_with_version_and_mac(backup_bytes, xpub)
×
3102
        assert backup_bytes == pw_decode_with_version_and_mac(encrypted, xpub), "encrypt failed"
×
3103
        return 'channel_backup:' + encrypted
×
3104

3105
    async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
4✔
3106
        if channel_id in self.channels:
×
UNCOV
3107
            chan = self.channels[channel_id]
×
UNCOV
3108
            peer = self._peers.get(chan.node_id)
×
3109
            chan.should_request_force_close = True
×
3110
            if peer:
×
3111
                peer.close_and_cleanup()  # to force a reconnect
×
3112
        elif connect_str:
×
3113
            peer = await self.add_peer(connect_str)
×
3114
            await peer.request_force_close(channel_id)
×
3115
        elif channel_id in self.channel_backups:
×
3116
            await self._request_force_close_from_backup(channel_id)
×
3117
        else:
3118
            raise Exception(f'Unknown channel {channel_id.hex()}')
×
3119

3120
    def import_channel_backup(self, data):
4✔
3121
        xpub = self.wallet.get_fingerprint()
×
UNCOV
3122
        cb_storage = ImportedChannelBackupStorage.from_encrypted_str(data, password=xpub)
×
UNCOV
3123
        channel_id = cb_storage.channel_id()
×
3124
        if channel_id.hex() in self.db.get_dict("channels"):
×
3125
            raise Exception('Channel already in wallet')
×
3126
        self.logger.info(f'importing channel backup: {channel_id.hex()}')
×
3127
        d = self.db.get_dict("imported_channel_backups")
×
3128
        d[channel_id.hex()] = cb_storage
×
3129
        with self.lock:
×
3130
            cb = ChannelBackup(cb_storage, lnworker=self)
×
3131
            self._channel_backups[channel_id] = cb
×
3132
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
3133
        self.wallet.save_db()
×
3134
        util.trigger_callback('channels_updated', self.wallet)
×
3135
        self.lnwatcher.add_channel(cb)
×
3136

3137
    def has_conflicting_backup_with(self, remote_node_id: bytes):
4✔
3138
        """Returns whether we have an active channel with this node on another device, using the same local node id."""
UNCOV
3139
        channel_backup_peers = [
×
3140
            cb.node_id for cb in self.channel_backups.values()
3141
            if (not cb.is_closed() and cb.get_local_pubkey() == self.node_keypair.pubkey)]
3142
        return any(remote_node_id.startswith(cb_peer_nodeid) for cb_peer_nodeid in channel_backup_peers)
×
3143

3144
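    # Editor's note (not part of lnworker.py): the startswith() comparison suggests cb.node_id
    # may hold only a node-id prefix (as stored in on-chain backups) rather than a full
    # 33-byte pubkey.  For example:
    #
    #     full_id = bytes.fromhex('02abcd' + '00' * 30)   # hypothetical 33-byte node id
    #     full_id.startswith(bytes.fromhex('02abcd'))     # True
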
    def remove_channel_backup(self, channel_id):
        chan = self.channel_backups[channel_id]
        assert chan.can_be_deleted()
        found = False
        onchain_backups = self.db.get_dict("onchain_channel_backups")
        imported_backups = self.db.get_dict("imported_channel_backups")
        if channel_id.hex() in onchain_backups:
            onchain_backups.pop(channel_id.hex())
            found = True
        if channel_id.hex() in imported_backups:
            imported_backups.pop(channel_id.hex())
            found = True
        if not found:
            raise Exception('Channel not found')
        with self.lock:
            self._channel_backups.pop(channel_id)
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
        self.wallet.save_db()
        util.trigger_callback('channels_updated', self.wallet)

    @log_exceptions
    async def _request_force_close_from_backup(self, channel_id: bytes):
        cb = self.channel_backups.get(channel_id)
        if not cb:
            raise Exception(f'channel backup not found {self.channel_backups}')
        cb = cb.cb  # storage
        self.logger.info(f'requesting channel force close: {channel_id.hex()}')
        if isinstance(cb, ImportedChannelBackupStorage):
            node_id = cb.node_id
            privkey = cb.privkey
            addresses = [(cb.host, cb.port, 0)]
        else:
            assert isinstance(cb, OnchainChannelBackupStorage)
            privkey = self.node_keypair.privkey
            for pubkey, peer_addr in trampolines_by_id().items():
                if pubkey.startswith(cb.node_id_prefix):
                    node_id = pubkey
                    addresses = [(peer_addr.host, peer_addr.port, 0)]
                    break
            else:
                # we will try with gossip (see below)
                addresses = []

        async def _request_fclose(addresses):
            for host, port, timestamp in addresses:
                peer_addr = LNPeerAddr(host, port, node_id)
                transport = LNTransport(privkey, peer_addr, e_proxy=ESocksProxy.from_network_settings(self.network))
                peer = Peer(self, node_id, transport, is_channel_backup=True)
                try:
                    async with OldTaskGroup(wait=any) as group:
                        await group.spawn(peer._message_loop())
                        await group.spawn(peer.request_force_close(channel_id))
                    return True
                except Exception as e:
                    self.logger.info(f'failed to connect {host} {e}')
                    continue
            else:
                return False
        # try first without gossip db
        success = await _request_fclose(addresses)
        if success:
            return
        # try with gossip db
        if self.uses_trampoline():
            raise Exception(_('Please enable gossip'))
        node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
        addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)
        if not addresses_from_gossip:
            raise Exception('Peer not found in gossip database')
        success = await _request_fclose(addresses_from_gossip)
        if not success:
            raise Exception('failed to connect')

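    # Editor's sketch (not part of lnworker.py): _request_force_close_from_backup() tries
    # connection targets in this order:
    #
    #     1. imported backup  -> the host/port stored in the backup itself
    #     2. on-chain backup  -> any hardcoded trampoline whose pubkey matches the stored node-id prefix
    #     3. otherwise        -> addresses looked up in the gossip channel_db (requires gossip enabled)
    #
    # Each attempt runs peer._message_loop() and peer.request_force_close() concurrently in an
    # OldTaskGroup(wait=any), so the group finishes as soon as the request has been sent or the
    # connection attempt fails.
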
    def maybe_add_backup_from_tx(self, tx):
        funding_address = None
        node_id_prefix = None
        for i, o in enumerate(tx.outputs()):
            script_type = get_script_type_from_output_script(o.scriptpubkey)
            if script_type == 'p2wsh':
                funding_index = i
                funding_address = o.address
                for o2 in tx.outputs():
                    if o2.scriptpubkey.startswith(bytes([opcodes.OP_RETURN])):
                        encrypted_data = o2.scriptpubkey[2:]
                        data = self.decrypt_cb_data(encrypted_data, funding_address)
                        if data.startswith(CB_MAGIC_BYTES):
                            node_id_prefix = data[len(CB_MAGIC_BYTES):]
        if node_id_prefix is None:
            return
        funding_txid = tx.txid()
        cb_storage = OnchainChannelBackupStorage(
            node_id_prefix=node_id_prefix,
            funding_txid=funding_txid,
            funding_index=funding_index,
            funding_address=funding_address,
            is_initiator=True)
        channel_id = cb_storage.channel_id().hex()
        if channel_id in self.db.get_dict("channels"):
            return
        self.logger.info("adding backup from tx")
        d = self.db.get_dict("onchain_channel_backups")
        d[channel_id] = cb_storage
        cb = ChannelBackup(cb_storage, lnworker=self)
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
        self.wallet.save_db()
        with self.lock:
            self._channel_backups[bfh(channel_id)] = cb
        util.trigger_callback('channels_updated', self.wallet)
        self.lnwatcher.add_channel(cb)

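    # Editor's sketch (not part of lnworker.py): maybe_add_backup_from_tx() recovers backup data
    # from an OP_RETURN output of our own funding transactions.  Based on the parsing above, the
    # assumed layout is:
    #
    #     scriptpubkey = OP_RETURN <pushdata-length> <encrypted_data>
    #     decrypt_cb_data(encrypted_data, funding_address) == CB_MAGIC_BYTES + node_id_prefix
    #
    # Only a prefix of the peer's node id is recovered, which is why the backup force-close path
    # has to match it against trampoline pubkeys or the gossip database.
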
    def save_forwarding_failure(
            self, payment_key: str, *,
            error_bytes: Optional[bytes] = None,
            failure_message: Optional['OnionRoutingFailure'] = None):
        error_hex = error_bytes.hex() if error_bytes else None
        failure_hex = failure_message.to_bytes().hex() if failure_message else None
        self.forwarding_failures[payment_key] = (error_hex, failure_hex)

    def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
        error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
        error_bytes = bytes.fromhex(error_hex) if error_hex else None
        failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
        return error_bytes, failure_message
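
    # Editor's sketch (not part of lnworker.py): forwarding failures are persisted as hex strings
    # and decoded lazily, so a save/get round trip behaves like:
    #
    #     lnworker.save_forwarding_failure(key, failure_message=failure)   # stores (None, failure_hex)
    #     err_bytes, failure = lnworker.get_forwarding_failure(key)        # err_bytes is None here
    #
    # where `key` is the payment key string used by the forwarding code and `failure` is an
    # OnionRoutingFailure reconstructed via from_bytes().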