
spesmilo / electrum / build 4878529344569344 (CirrusCI)

04 Mar 2025 10:05AM UTC coverage: 60.716% (-0.02% from 60.731%)

Pull Request #9587 (f321x): Disable mpp flags in invoice creation if jit channel is required and consider available liquidity
Commit message: disable mpp flags in invoice creation if jit channel is required, check against available liquidity if we need a jit channel

5 of 15 new or added lines in 2 files covered (33.33%).
847 existing lines in 6 files now uncovered.
20678 of 34057 relevant lines covered (60.72%).
3.03 hits per line.
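As a quick check on the arithmetic behind these figures: 20678 / 34057 ≈ 0.6072, which matches the reported 60.72% overall coverage, and 5 / 15 ≈ 33.33% for the newly added lines.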
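The pull request summary above says that MPP flags are disabled during invoice creation when a just-in-time (JIT) channel is required, and that available liquidity is checked to decide whether a JIT channel is needed. As a rough illustration only (this is not the actual diff of PR #9587, and the helper name is hypothetical), clearing the MPP feature bits from an invoice's features could look roughly like this, using the LnFeatures flags from electrum.lnutil:

# Illustrative sketch, not the code added by PR #9587.
from electrum.lnutil import LnFeatures

def strip_mpp_if_jit_needed(invoice_features: LnFeatures, needs_jit_channel: bool) -> LnFeatures:
    # A JIT (zeroconf) channel open is negotiated around a single incoming HTLC
    # (see open_channel_just_in_time below), so the invoice should not invite
    # the sender to split the payment into multiple parts.
    if needs_jit_channel:
        invoice_features &= ~(LnFeatures.BASIC_MPP_OPT | LnFeatures.BASIC_MPP_REQ)
    return invoice_features

Whether a JIT channel is needed would, per the PR description, be decided by comparing the invoice amount against the available inbound liquidity of the existing channels.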

Source File

/electrum/lnworker.py (50.82% covered)
1
# Copyright (C) 2018 The Electrum developers
2
# Distributed under the MIT software license, see the accompanying
3
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4

5
import asyncio
5✔
6
import os
5✔
7
from decimal import Decimal
5✔
8
import random
5✔
9
import time
5✔
10
from enum import IntEnum
5✔
11
from typing import (
5✔
12
    Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator,
13
    Callable, Awaitable
14
)
15
import threading
5✔
16
import socket
5✔
17
from functools import partial
5✔
18
from collections import defaultdict
5✔
19
import concurrent
5✔
20
from concurrent import futures
5✔
21
import urllib.parse
5✔
22
import itertools
5✔
23

24
import aiohttp
5✔
25
import dns.resolver
5✔
26
import dns.exception
5✔
27
from aiorpcx import run_in_thread, NetAddress, ignore_after
5✔
28

29
from .logging import Logger
5✔
30
from .i18n import _
5✔
31
from .json_db import stored_in
5✔
32
from .channel_db import UpdateStatus, ChannelDBNotLoaded, get_mychannel_info, get_mychannel_policy
5✔
33

34
from . import constants, util
5✔
35
from .util import (
5✔
36
    profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
37
    event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
38
    make_aiohttp_session, timestamp_to_datetime, random_shuffled_copy, is_private_netaddress,
39
    UnrelatedTransactionException, LightningHistoryItem
40
)
41
from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice
5✔
42
from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress
5✔
43
from .bip32 import BIP32Node
5✔
44
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_TIMESTAMP_INF
5✔
45
from .transaction import (
5✔
46
    Transaction, get_script_type_from_output_script, PartialTxOutput, PartialTransaction, PartialTxInput
47
)
48
from .crypto import (
5✔
49
    sha256, chacha20_encrypt, chacha20_decrypt, pw_encode_with_version_and_mac, pw_decode_with_version_and_mac
50
)
51

52
from .onion_message import OnionMessageManager
5✔
53
from .lntransport import LNTransport, LNResponderTransport, LNTransportBase, LNPeerAddr, split_host_port, extract_nodeid, ConnStringFormatError
5✔
54
from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
5✔
55
from .lnaddr import lnencode, LnAddr, lndecode
5✔
56
from .lnchannel import Channel, AbstractChannel, ChannelState, PeerState, HTLCWithStatus, ChannelBackup
5✔
57
from .lnrater import LNRater
5✔
58
from .lnutil import (
5✔
59
    get_compressed_pubkey_from_bech32, serialize_htlc_key, deserialize_htlc_key, PaymentFailure, generate_keypair,
60
    LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
61
    ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
62
    OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
63
    NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE
64
)
65
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
5✔
66
from .lnmsg import decode_msg
5✔
67
from .lnrouter import (
5✔
68
    RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_within_budget, NoChannelPolicy, LNPathInconsistent
69
)
70
from .lnwatcher import LNWalletWatcher
5✔
71
from .submarine_swaps import SwapManager
5✔
72
from .mpp_split import suggest_splits, SplitConfigRating
5✔
73
from .trampoline import (
5✔
74
    create_trampoline_route_and_onion, is_legacy_relay, trampolines_by_id, hardcoded_trampoline_nodes,
75
    is_hardcoded_trampoline
76
)
77

78
if TYPE_CHECKING:
5✔
79
    from .network import Network
×
80
    from .wallet import Abstract_Wallet
×
81
    from .channel_db import ChannelDB
×
82
    from .simple_config import SimpleConfig
×
83

84

85
SAVED_PR_STATUS = [PR_PAID, PR_UNPAID]  # status that are persisted
5✔
86

87
NUM_PEERS_TARGET = 4
5✔
88

89
# onchain channel backup data
90
CB_VERSION = 0
5✔
91
CB_MAGIC_BYTES = bytes([0, 0, 0, CB_VERSION])
5✔
92
NODE_ID_PREFIX_LEN = 16
5✔
93

94

95
class PaymentDirection(IntEnum):
5✔
96
    SENT = 0
5✔
97
    RECEIVED = 1
5✔
98
    SELF_PAYMENT = 2
5✔
99
    FORWARDING = 3
5✔
100

101

102
class PaymentInfo(NamedTuple):
5✔
103
    payment_hash: bytes
5✔
104
    amount_msat: Optional[int]
5✔
105
    direction: int
5✔
106
    status: int
5✔
107

108

109
# Note: these states are persisted in the wallet file.
110
# Do not modify them without performing a wallet db upgrade
111
class RecvMPPResolution(IntEnum):
5✔
112
    WAITING = 0
5✔
113
    EXPIRED = 1
5✔
114
    ACCEPTED = 2
5✔
115
    FAILED = 3
5✔
116

117

118
class ReceivedMPPStatus(NamedTuple):
5✔
119
    resolution: RecvMPPResolution
5✔
120
    expected_msat: int
5✔
121
    htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
5✔
122

123
    @stored_in('received_mpp_htlcs', tuple)
5✔
124
    def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
5✔
125
        htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list])
×
126
        return ReceivedMPPStatus(
×
127
            resolution=RecvMPPResolution(resolution),
128
            expected_msat=expected_msat,
129
            htlc_set=htlc_set)
130

131

132
SentHtlcKey = Tuple[bytes, ShortChannelID, int]  # RHASH, scid, htlc_id
5✔
133

134

135
class SentHtlcInfo(NamedTuple):
5✔
136
    route: LNPaymentRoute
5✔
137
    payment_secret_orig: bytes
5✔
138
    payment_secret_bucket: bytes
5✔
139
    amount_msat: int
5✔
140
    bucket_msat: int
5✔
141
    amount_receiver_msat: int
5✔
142
    trampoline_fee_level: Optional[int]
5✔
143
    trampoline_route: Optional[LNPaymentRoute]
5✔
144

145

146
class ErrorAddingPeer(Exception): pass
5✔
147

148

149
# set some feature flags as baseline for both LNWallet and LNGossip
150
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
151
BASE_FEATURES = (
5✔
152
    LnFeatures(0)
153
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
154
    | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
155
    | LnFeatures.VAR_ONION_OPT
156
    | LnFeatures.PAYMENT_SECRET_OPT
157
    | LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
158
)
159

160
# we do not want to receive unrequested gossip (see lnpeer.maybe_save_remote_update)
161
LNWALLET_FEATURES = (
5✔
162
    BASE_FEATURES
163
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
164
    | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
165
    | LnFeatures.GOSSIP_QUERIES_REQ
166
    | LnFeatures.VAR_ONION_REQ
167
    | LnFeatures.PAYMENT_SECRET_REQ
168
    | LnFeatures.BASIC_MPP_OPT
169
    | LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
170
    | LnFeatures.OPTION_SHUTDOWN_ANYSEGWIT_OPT
171
    | LnFeatures.OPTION_CHANNEL_TYPE_OPT
172
    | LnFeatures.OPTION_SCID_ALIAS_OPT
173
    | LnFeatures.OPTION_SUPPORT_LARGE_CHANNEL_OPT
174
)
175

176
LNGOSSIP_FEATURES = (
5✔
177
    BASE_FEATURES
178
    | LnFeatures.GOSSIP_QUERIES_OPT
179
    | LnFeatures.GOSSIP_QUERIES_REQ
180
)
181

182

183
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
5✔
184

185
    def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
5✔
186
        Logger.__init__(self)
5✔
187
        NetworkRetryManager.__init__(
5✔
188
            self,
189
            max_retry_delay_normal=3600,
190
            init_retry_delay_normal=600,
191
            max_retry_delay_urgent=300,
192
            init_retry_delay_urgent=4,
193
        )
194
        self.lock = threading.RLock()
5✔
195
        self.node_keypair = node_keypair
5✔
196
        self._peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer  # needs self.lock
5✔
197
        self.taskgroup = OldTaskGroup()
5✔
198
        self.listen_server = None  # type: Optional[asyncio.AbstractServer]
5✔
199
        self.features = features
5✔
200
        self.network = None  # type: Optional[Network]
5✔
201
        self.config = config
5✔
202
        self.stopping_soon = False  # whether we are being shut down
5✔
203
        self.register_callbacks()
5✔
204

205
    @property
5✔
206
    def channel_db(self) -> 'ChannelDB':
5✔
207
        return self.network.channel_db if self.network else None
×
208

209
    def uses_trampoline(self) -> bool:
5✔
210
        return not bool(self.channel_db)
×
211

212
    @property
5✔
213
    def peers(self) -> Mapping[bytes, Peer]:
5✔
214
        """Returns a read-only copy of peers."""
215
        with self.lock:
×
216
            return self._peers.copy()
×
217

218
    def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
5✔
219
        return {}
×
220

221
    def get_node_alias(self, node_id: bytes) -> Optional[str]:
5✔
222
        """Returns the alias of the node, or None if unknown."""
223
        node_alias = None
×
224
        if not self.uses_trampoline():
×
225
            node_info = self.channel_db.get_node_info_for_node_id(node_id)
×
226
            if node_info:
×
227
                node_alias = node_info.alias
×
228
        else:
229
            for k, v in hardcoded_trampoline_nodes().items():
×
230
                if v.pubkey.startswith(node_id):
×
231
                    node_alias = k
×
232
                    break
×
233
        return node_alias
×
234

235
    async def maybe_listen(self):
5✔
236
        # FIXME: only one LNWorker can listen at a time (single port)
237
        listen_addr = self.config.LIGHTNING_LISTEN
×
238
        if listen_addr:
×
239
            self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
×
240
            try:
×
241
                netaddr = NetAddress.from_string(listen_addr)
×
242
            except Exception as e:
×
243
                self.logger.error(f"failed to parse config key '{self.config.cv.LIGHTNING_LISTEN.key()}'. got: {e!r}")
×
244
                return
×
245
            addr = str(netaddr.host)
×
246

247
            async def cb(reader, writer):
×
248
                transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
×
249
                try:
×
250
                    node_id = await transport.handshake()
×
251
                except Exception as e:
×
252
                    self.logger.info(f'handshake failure from incoming connection: {e!r}')
×
253
                    return
×
254
                await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
255
            try:
×
256
                self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
×
257
            except OSError as e:
×
258
                self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")
×
259

260
    async def main_loop(self):
5✔
261
        self.logger.info("starting taskgroup.")
×
262
        try:
×
263
            async with self.taskgroup as group:
×
264
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
×
265
        except Exception as e:
×
266
            self.logger.exception("taskgroup died.")
×
267
        finally:
268
            self.logger.info("taskgroup stopped.")
×
269

270
    async def _maintain_connectivity(self):
5✔
271
        while True:
×
272
            await asyncio.sleep(1)
×
273
            if self.stopping_soon:
×
274
                return
×
275
            now = time.time()
×
276
            if len(self._peers) >= NUM_PEERS_TARGET:
×
277
                continue
×
278
            peers = await self._get_next_peers_to_try()
×
279
            for peer in peers:
×
280
                if self._can_retry_addr(peer, now=now):
×
281
                    try:
×
282
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
×
283
                    except ErrorAddingPeer as e:
×
284
                        self.logger.info(f"failed to add peer: {peer}. exc: {e!r}")
×
285

286
    async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
5✔
287
        if node_id in self._peers:
×
288
            return self._peers[node_id]
×
289
        port = int(port)
×
290
        peer_addr = LNPeerAddr(host, port, node_id)
×
291
        self._trying_addr_now(peer_addr)
×
292
        self.logger.info(f"adding peer {peer_addr}")
×
293
        if node_id == self.node_keypair.pubkey:
×
294
            raise ErrorAddingPeer("cannot connect to self")
×
295
        transport = LNTransport(self.node_keypair.privkey, peer_addr,
×
296
                                e_proxy=ESocksProxy.from_network_settings(self.network))
297
        peer = await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
298
        assert peer
×
299
        return peer
×
300

301
    async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransportBase) -> Optional[Peer]:
5✔
302
        with self.lock:
×
303
            existing_peer = self._peers.get(node_id)
×
304
            if existing_peer:
×
305
                # Two instances of the same wallet are attempting to connect simultaneously.
306
                # If we let the new connection replace the existing one, the two instances might
307
                # both keep trying to reconnect, resulting in neither being usable.
308
                if existing_peer.is_initialized():
×
309
                    # give priority to the existing connection
310
                    return
×
311
                else:
312
                    # Use the new connection. (e.g. old peer might be an outgoing connection
313
                    # for an outdated host/port that will never connect)
314
                    existing_peer.close_and_cleanup()
×
315
            peer = Peer(self, node_id, transport)
×
316
            assert node_id not in self._peers
×
317
            self._peers[node_id] = peer
×
318
        await self.taskgroup.spawn(peer.main_loop())
×
319
        return peer
×
320

321
    def peer_closed(self, peer: Peer) -> None:
5✔
322
        with self.lock:
×
323
            peer2 = self._peers.get(peer.pubkey)
×
324
            if peer2 is peer:
×
325
                self._peers.pop(peer.pubkey)
×
326

327
    def num_peers(self) -> int:
5✔
328
        return sum([p.is_initialized() for p in self.peers.values()])
×
329

330
    def start_network(self, network: 'Network'):
5✔
331
        assert network
×
332
        assert self.network is None, "already started"
×
333
        self.network = network
×
334
        self._add_peers_from_config()
×
335
        asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
×
336

337
    async def stop(self):
5✔
338
        if self.listen_server:
5✔
339
            self.listen_server.close()
×
340
        self.unregister_callbacks()
5✔
341
        await self.taskgroup.cancel_remaining()
5✔
342

343
    def _add_peers_from_config(self):
5✔
344
        peer_list = self.config.LIGHTNING_PEERS or []
×
345
        for host, port, pubkey in peer_list:
×
346
            asyncio.run_coroutine_threadsafe(
×
347
                self._add_peer(host, int(port), bfh(pubkey)),
348
                self.network.asyncio_loop)
349

350
    def is_good_peer(self, peer: LNPeerAddr) -> bool:
5✔
351
        # the purpose of this method is to filter peers that advertise the desired feature bits
352
        # it is disabled for now, because feature bits published in node announcements seem to be unreliable
353
        return True
×
354
        node_id = peer.pubkey
355
        node = self.channel_db._nodes.get(node_id)
356
        if not node:
357
            return False
358
        try:
359
            ln_compare_features(self.features, node.features)
360
        except IncompatibleLightningFeatures:
361
            return False
362
        #self.logger.info(f'is_good {peer.host}')
363
        return True
364

365
    def on_peer_successfully_established(self, peer: Peer) -> None:
5✔
366
        if isinstance(peer.transport, LNTransport):
5✔
367
            peer_addr = peer.transport.peer_addr
×
368
            # reset connection attempt count
369
            self._on_connection_successfully_established(peer_addr)
×
370
            if not self.uses_trampoline():
×
371
                # add into channel db
372
                self.channel_db.add_recent_peer(peer_addr)
×
373
            # save network address into channels we might have with peer
374
            for chan in peer.channels.values():
×
375
                chan.add_or_update_peer_addr(peer_addr)
×
376

377
    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
5✔
378
        now = time.time()
×
379
        await self.channel_db.data_loaded.wait()
×
380
        # first try from recent peers
381
        recent_peers = self.channel_db.get_recent_peers()
×
382
        for peer in recent_peers:
×
383
            if not peer:
×
384
                continue
×
385
            if peer.pubkey in self._peers:
×
386
                continue
×
387
            if not self._can_retry_addr(peer, now=now):
×
388
                continue
×
389
            if not self.is_good_peer(peer):
×
390
                continue
×
391
            return [peer]
×
392
        # try random peer from graph
393
        unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
×
394
        if unconnected_nodes:
×
395
            for node_id in unconnected_nodes:
×
396
                addrs = self.channel_db.get_node_addresses(node_id)
×
397
                if not addrs:
×
398
                    continue
×
399
                host, port, timestamp = self.choose_preferred_address(list(addrs))
×
400
                try:
×
401
                    peer = LNPeerAddr(host, port, node_id)
×
402
                except ValueError:
×
403
                    continue
×
404
                if not self._can_retry_addr(peer, now=now):
×
405
                    continue
×
406
                if not self.is_good_peer(peer):
×
407
                    continue
×
408
                #self.logger.info('taking random ln peer from our channel db')
409
                return [peer]
×
410

411
        # getting desperate... let's try hardcoded fallback list of peers
412
        fallback_list = constants.net.FALLBACK_LN_NODES
×
413
        fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
×
414
        if fallback_list:
×
415
            return [random.choice(fallback_list)]
×
416

417
        # last resort: try dns seeds (BOLT-10)
418
        return await run_in_thread(self._get_peers_from_dns_seeds)
×
419

420
    def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
5✔
421
        # NOTE: potentially long blocking call, do not run directly on asyncio event loop.
422
        # Return several peers to reduce the number of dns queries.
423
        if not constants.net.LN_DNS_SEEDS:
×
424
            return []
×
425
        dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
×
426
        self.logger.info('asking dns seed "{}" for ln peers'.format(dns_seed))
×
427
        try:
×
428
            # note: this might block for several seconds
429
            # this will include bech32-encoded-pubkeys and ports
430
            srv_answers = resolve_dns_srv('r{}.{}'.format(
×
431
                constants.net.LN_REALM_BYTE, dns_seed))
432
        except dns.exception.DNSException as e:
×
433
            self.logger.info(f'failed querying (1) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
434
            return []
×
435
        random.shuffle(srv_answers)
×
436
        num_peers = 2 * NUM_PEERS_TARGET
×
437
        srv_answers = srv_answers[:num_peers]
×
438
        # we now have pubkeys and ports but host is still needed
439
        peers = []
×
440
        for srv_ans in srv_answers:
×
441
            try:
×
442
                # note: this might block for several seconds
443
                answers = dns.resolver.resolve(srv_ans['host'])
×
444
            except dns.exception.DNSException as e:
×
445
                self.logger.info(f'failed querying (2) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
446
                continue
×
447
            try:
×
448
                ln_host = str(answers[0])
×
449
                port = int(srv_ans['port'])
×
450
                bech32_pubkey = srv_ans['host'].split('.')[0]
×
451
                pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
×
452
                peers.append(LNPeerAddr(ln_host, port, pubkey))
×
453
            except Exception as e:
×
454
                self.logger.info(f'error with parsing peer from dns seed: {repr(e)}')
×
455
                continue
×
456
        self.logger.info(f'got {len(peers)} ln peers from dns seed')
×
457
        return peers
×
458

459
    @staticmethod
5✔
460
    def choose_preferred_address(addr_list: Sequence[Tuple[str, int, int]]) -> Tuple[str, int, int]:
5✔
461
        assert len(addr_list) >= 1
×
462
        # choose the most recent one that is an IP
463
        for host, port, timestamp in sorted(addr_list, key=lambda a: -a[2]):
×
464
            if is_ip_address(host):
×
465
                return host, port, timestamp
×
466
        # otherwise choose one at random
467
        # TODO maybe filter out onion if not on tor?
468
        choice = random.choice(addr_list)
×
469
        return choice
×
470

471
    @event_listener
5✔
472
    def on_event_proxy_set(self, *args):
5✔
473
        for peer in self.peers.values():
×
474
            peer.close_and_cleanup()
×
475
        self._clear_addr_retry_times()
×
476

477
    @log_exceptions
5✔
478
    async def add_peer(self, connect_str: str) -> Peer:
5✔
479
        node_id, rest = extract_nodeid(connect_str)
×
480
        peer = self._peers.get(node_id)
×
481
        if not peer:
×
482
            if rest is not None:
×
483
                host, port = split_host_port(rest)
×
484
            else:
485
                if self.uses_trampoline():
×
486
                    addr = trampolines_by_id().get(node_id)
×
487
                    if not addr:
×
488
                        raise ConnStringFormatError(_('Address unknown for node:') + ' ' + node_id.hex())
×
489
                    host, port = addr.host, addr.port
×
490
                else:
491
                    addrs = self.channel_db.get_node_addresses(node_id)
×
492
                    if not addrs:
×
493
                        raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + node_id.hex())
×
494
                    host, port, timestamp = self.choose_preferred_address(list(addrs))
×
495
            port = int(port)
×
496

497
            if not self.network.proxy:
×
498
                # Try DNS-resolving the host (if needed). This is simply so that
499
                # the caller gets a nice exception if it cannot be resolved.
500
                # (we don't do the DNS lookup if a proxy is set, to avoid a DNS-leak)
501
                if host.endswith('.onion'):
×
502
                    raise ConnStringFormatError(_('.onion address, but no proxy configured'))
×
503
                try:
×
504
                    await asyncio.get_running_loop().getaddrinfo(host, port)
×
505
                except socket.gaierror:
×
506
                    raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
×
507

508
            # add peer
509
            peer = await self._add_peer(host, port, node_id)
×
510
        return peer
×
511

512

513
class LNGossip(LNWorker):
5✔
514
    max_age = 14*24*3600
5✔
515
    LOGGING_SHORTCUT = 'g'
5✔
516

517
    def __init__(self, config: 'SimpleConfig'):
5✔
518
        seed = os.urandom(32)
×
519
        node = BIP32Node.from_rootseed(seed, xtype='standard')
×
520
        xprv = node.to_xprv()
×
521
        node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
×
522
        LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
×
523
        self.unknown_ids = set()
×
524

525
    def start_network(self, network: 'Network'):
5✔
526
        super().start_network(network)
×
527
        for coro in [
×
528
                self._maintain_connectivity(),
529
                self.maintain_db(),
530
        ]:
531
            tg_coro = self.taskgroup.spawn(coro)
×
532
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
533

534
    async def maintain_db(self):
5✔
535
        await self.channel_db.data_loaded.wait()
×
536
        while True:
×
537
            if len(self.unknown_ids) == 0:
×
538
                self.channel_db.prune_old_policies(self.max_age)
×
539
                self.channel_db.prune_orphaned_channels()
×
540
            await asyncio.sleep(120)
×
541

542
    async def add_new_ids(self, ids: Iterable[bytes]):
5✔
543
        known = self.channel_db.get_channel_ids()
×
544
        new = set(ids) - set(known)
×
545
        self.unknown_ids.update(new)
×
546
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
547
        util.trigger_callback('gossip_peers', self.num_peers())
×
548
        util.trigger_callback('ln_gossip_sync_progress')
×
549

550
    def get_ids_to_query(self) -> Sequence[bytes]:
5✔
551
        N = 500
×
552
        l = list(self.unknown_ids)
×
553
        self.unknown_ids = set(l[N:])
×
554
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
555
        util.trigger_callback('ln_gossip_sync_progress')
×
556
        return l[0:N]
×
557

558
    def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int], Optional[int]]:
5✔
559
        """Estimates the gossip synchronization process and returns the number
560
        of synchronized channels, the total channels in the network and a
561
        rescaled percentage of the synchronization process."""
562
        if self.num_peers() == 0:
×
563
            return None, None, None
×
564
        nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
×
565
        num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
×
566
        # some channels will never have two policies (only one is in gossip?...)
567
        # so if we have at least 1 policy for a channel, we consider that channel "complete" here
568
        current_est = num_db_channels - nchans_with_0p
×
569
        total_est = len(self.unknown_ids) + num_db_channels
×
570

571
        progress = current_est / total_est if total_est and current_est else 0
×
572
        progress_percent = (1.0 / 0.95 * progress) * 100
×
573
        progress_percent = min(progress_percent, 100)
×
574
        progress_percent = round(progress_percent)
×
575
        # take a minimal number of synchronized channels to get a more accurate
576
        # percentage estimate
577
        if current_est < 200:
×
578
            progress_percent = 0
×
579
        return current_est, total_est, progress_percent
×
580

581
    async def process_gossip(self, chan_anns, node_anns, chan_upds):
5✔
582
        # note: we run in the originating peer's TaskGroup, so we can safely raise here
583
        #       and disconnect only from that peer
584
        await self.channel_db.data_loaded.wait()
×
585
        self.logger.debug(f'process_gossip {len(chan_anns)} {len(node_anns)} {len(chan_upds)}')
×
586

587
        # channel announcements
588
        def process_chan_anns():
×
589
            for payload in chan_anns:
×
590
                self.channel_db.verify_channel_announcement(payload)
×
591
            self.channel_db.add_channel_announcements(chan_anns)
×
592
        await run_in_thread(process_chan_anns)
×
593

594
        # node announcements
595
        def process_node_anns():
×
596
            for payload in node_anns:
×
597
                self.channel_db.verify_node_announcement(payload)
×
598
            self.channel_db.add_node_announcements(node_anns)
×
599
        await run_in_thread(process_node_anns)
×
600
        # channel updates
601
        categorized_chan_upds = await run_in_thread(partial(
×
602
            self.channel_db.add_channel_updates,
603
            chan_upds,
604
            max_age=self.max_age))
605
        orphaned = categorized_chan_upds.orphaned
×
606
        if orphaned:
×
607
            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
×
608
            orphaned_ids = [c['short_channel_id'] for c in orphaned]
×
609
            await self.add_new_ids(orphaned_ids)
×
610
        if categorized_chan_upds.good:
×
611
            self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
×
612

613

614
class PaySession(Logger):
5✔
615
    def __init__(
5✔
616
            self,
617
            *,
618
            payment_hash: bytes,
619
            payment_secret: bytes,
620
            initial_trampoline_fee_level: int,
621
            invoice_features: int,
622
            r_tags,
623
            min_final_cltv_delta: int,  # delta for last node (typically from invoice)
624
            amount_to_pay: int,  # total payment amount final receiver will get
625
            invoice_pubkey: bytes,
626
            uses_trampoline: bool,  # whether sender uses trampoline or gossip
627
            use_two_trampolines: bool,  # whether legacy payments will try to use two trampolines
628
    ):
629
        assert payment_hash
5✔
630
        assert payment_secret
5✔
631
        self.payment_hash = payment_hash
5✔
632
        self.payment_secret = payment_secret
5✔
633
        self.payment_key = payment_hash + payment_secret
5✔
634
        Logger.__init__(self)
5✔
635

636
        self.invoice_features = LnFeatures(invoice_features)
5✔
637
        self.r_tags = r_tags
5✔
638
        self.min_final_cltv_delta = min_final_cltv_delta
5✔
639
        self.amount_to_pay = amount_to_pay
5✔
640
        self.invoice_pubkey = invoice_pubkey
5✔
641

642
        self.sent_htlcs_q = asyncio.Queue()  # type: asyncio.Queue[HtlcLog]
5✔
643
        self.start_time = time.time()
5✔
644

645
        self.uses_trampoline = uses_trampoline
5✔
646
        self.trampoline_fee_level = initial_trampoline_fee_level
5✔
647
        self.failed_trampoline_routes = []
5✔
648
        self.use_two_trampolines = use_two_trampolines
5✔
649
        self._sent_buckets = dict()  # psecret_bucket -> (amount_sent, amount_failed)
5✔
650

651
        self._amount_inflight = 0  # what we sent in htlcs (that receiver gets, without fees)
5✔
652
        self._nhtlcs_inflight = 0
5✔
653
        self.is_active = True  # is still trying to send new htlcs?
5✔
654

655
    def diagnostic_name(self):
5✔
656
        pkey = sha256(self.payment_key)
5✔
657
        return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
5✔
658

659
    def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
5✔
660
        if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
5✔
661
            self.trampoline_fee_level += 1
5✔
662
            self.failed_trampoline_routes = []
5✔
663
            self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
5✔
664
        else:
665
            self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
5✔
666

667
    def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
5✔
668
        # FIXME The trampoline nodes in the path are chosen randomly.
669
        #       Some of the errors might depend on how we have chosen them.
670
        #       Having more attempts is currently useful in part because of the randomness,
671
        #       instead we should give feedback to create_routes_for_payment.
672
        # Sometimes the trampoline node fails to send a payment and returns
673
        # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
674
        if failure_msg.code in (
5✔
675
                OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
676
                OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
677
                OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
678
            # TODO: parse the node policy here (not returned by eclair yet)
679
            # TODO: erring node is always the first trampoline even if second
680
            #  trampoline demands more fees, we can't influence this
681
            self.maybe_raise_trampoline_fee(htlc_log)
5✔
682
        elif self.use_two_trampolines:
5✔
683
            self.use_two_trampolines = False
×
684
        elif failure_msg.code in (
5✔
685
                OnionFailureCode.UNKNOWN_NEXT_PEER,
686
                OnionFailureCode.TEMPORARY_NODE_FAILURE):
687
            trampoline_route = htlc_log.route
5✔
688
            r = [hop.end_node.hex() for hop in trampoline_route]
5✔
689
            self.logger.info(f'failed trampoline route: {r}')
5✔
690
            if r not in self.failed_trampoline_routes:
5✔
691
                self.failed_trampoline_routes.append(r)
5✔
692
            else:
693
                pass  # maybe the route was reused between different MPP parts
×
694
        else:
695
            raise PaymentFailure(failure_msg.code_name())
5✔
696

697
    async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
5✔
698
        self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
5✔
699
        htlc_log = await self.sent_htlcs_q.get()
5✔
700
        self._amount_inflight -= htlc_log.amount_msat
5✔
701
        self._nhtlcs_inflight -= 1
5✔
702
        if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
5✔
703
            raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
×
704
        return htlc_log
5✔
705

706
    def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
5✔
707
        self._nhtlcs_inflight += 1
5✔
708
        self._amount_inflight += sent_htlc_info.amount_receiver_msat
5✔
709
        if self._amount_inflight > self.amount_to_pay:  # safety belts
5✔
710
            raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
×
711
        shi = sent_htlc_info
5✔
712
        bkey = shi.payment_secret_bucket
5✔
713
        # if we sent MPP to a trampoline, add item to sent_buckets
714
        if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
5✔
715
            if bkey not in self._sent_buckets:
5✔
716
                self._sent_buckets[bkey] = (0, 0)
5✔
717
            amount_sent, amount_failed = self._sent_buckets[bkey]
5✔
718
            amount_sent += shi.amount_receiver_msat
5✔
719
            self._sent_buckets[bkey] = amount_sent, amount_failed
5✔
720

721
    def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
5✔
722
        shi = sent_htlc_info
5✔
723
        # check sent_buckets if we use trampoline
724
        bkey = shi.payment_secret_bucket
5✔
725
        if self.uses_trampoline and bkey in self._sent_buckets:
5✔
726
            amount_sent, amount_failed = self._sent_buckets[bkey]
5✔
727
            amount_failed += shi.amount_receiver_msat
5✔
728
            self._sent_buckets[bkey] = amount_sent, amount_failed
5✔
729
            if amount_sent != amount_failed:
5✔
730
                self.logger.info('bucket still active...')
5✔
731
                return None
5✔
732
            self.logger.info('bucket failed')
5✔
733
            return amount_sent
5✔
734
        # not using trampoline buckets
735
        return shi.amount_receiver_msat
5✔
736

737
    def get_outstanding_amount_to_send(self) -> int:
5✔
738
        return self.amount_to_pay - self._amount_inflight
5✔
739

740
    def can_be_deleted(self) -> bool:
5✔
741
        """Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
742
        if self.is_active:
5✔
743
            return False
5✔
744
        # note: no one is consuming from sent_htlcs_q anymore
745
        nhtlcs_resolved = self.sent_htlcs_q.qsize()
5✔
746
        assert nhtlcs_resolved <= self._nhtlcs_inflight
5✔
747
        return nhtlcs_resolved == self._nhtlcs_inflight
5✔
748

749

750
class LNWallet(LNWorker):
5✔
751

752
    lnwatcher: Optional['LNWalletWatcher']
5✔
753
    MPP_EXPIRY = 120
5✔
754
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds
5✔
755
    PAYMENT_TIMEOUT = 120
5✔
756
    MPP_SPLIT_PART_FRACTION = 0.2
5✔
757
    MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
5✔
758

759
    def __init__(self, wallet: 'Abstract_Wallet', xprv):
5✔
760
        self.wallet = wallet
5✔
761
        self.config = wallet.config
5✔
762
        self.db = wallet.db
5✔
763
        self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
5✔
764
        self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey
5✔
765
        self.static_payment_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_BASE)
5✔
766
        self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
5✔
767
        self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
5✔
768
        Logger.__init__(self)
5✔
769
        features = LNWALLET_FEATURES
5✔
770
        if self.config.ENABLE_ANCHOR_CHANNELS:
5✔
771
            features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
×
772
        if self.config.ACCEPT_ZEROCONF_CHANNELS:
5✔
773
            features |= LnFeatures.OPTION_ZEROCONF_OPT
×
774
        LNWorker.__init__(self, self.node_keypair, features, config=self.config)
5✔
775
        self.lnwatcher = None
5✔
776
        self.lnrater: LNRater = None
5✔
777
        self.payment_info = self.db.get_dict('lightning_payments')     # RHASH -> amount, direction, is_paid
5✔
778
        self.preimages = self.db.get_dict('lightning_preimages')   # RHASH -> preimage
5✔
779
        self._bolt11_cache = {}
5✔
780
        # note: this sweep_address is only used as fallback; as it might result in address-reuse
781
        self.logs = defaultdict(list)  # type: Dict[str, List[HtlcLog]]  # key is RHASH  # (not persisted)
5✔
782
        # used in tests
783
        self.enable_htlc_settle = True
5✔
784
        self.enable_htlc_settle_onchain = True
5✔
785
        self.enable_htlc_forwarding = True
5✔
786

787
        # note: accessing channels (besides simple lookup) needs self.lock!
788
        self._channels = {}  # type: Dict[bytes, Channel]
5✔
789
        channels = self.db.get_dict("channels")
5✔
790
        for channel_id, c in random_shuffled_copy(channels.items()):
5✔
791
            self._channels[bfh(channel_id)] = chan = Channel(c, lnworker=self)
5✔
792
            self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
5✔
793

794
        self._channel_backups = {}  # type: Dict[bytes, ChannelBackup]
5✔
795
        # order is important: imported should overwrite onchain
796
        for name in ["onchain_channel_backups", "imported_channel_backups"]:
5✔
797
            channel_backups = self.db.get_dict(name)
5✔
798
            for channel_id, storage in channel_backups.items():
5✔
799
                self._channel_backups[bfh(channel_id)] = cb = ChannelBackup(storage, lnworker=self)
×
800
                self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
801

802
        self._paysessions = dict()                      # type: Dict[bytes, PaySession]
5✔
803
        self.sent_htlcs_info = dict()                   # type: Dict[SentHtlcKey, SentHtlcInfo]
5✔
804
        self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs')   # type: Dict[str, ReceivedMPPStatus]  # payment_key -> ReceivedMPPStatus
5✔
805

806
        # detect inflight payments
807
        self.inflight_payments = set()        # (not persisted) keys of invoices that are in PR_INFLIGHT state
5✔
808
        for payment_hash in self.get_payments(status='inflight').keys():
5✔
809
            self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
×
810

811
        # payment forwarding
812
        self.active_forwardings = self.db.get_dict('active_forwardings')    # type: Dict[str, List[str]]        # Dict: payment_key -> list of htlc_keys
5✔
813
        self.forwarding_failures = self.db.get_dict('forwarding_failures')  # type: Dict[str, Tuple[str, str]]  # Dict: payment_key -> (error_bytes, error_message)
5✔
814
        self.downstream_to_upstream_htlc = {}                               # type: Dict[str, str]              # Dict: htlc_key -> htlc_key (not persisted)
5✔
815

816
        # payment_hash -> callback:
817
        self.hold_invoice_callbacks = {}                # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
5✔
818
        self.payment_bundles = []                       # lists of hashes. todo:persist
5✔
819

820
        self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
5✔
821
        self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
5✔
822
        self.onion_message_manager = OnionMessageManager(self)
5✔
823

824
    def has_deterministic_node_id(self) -> bool:
5✔
825
        return bool(self.db.get('lightning_xprv'))
×
826

827
    def can_have_recoverable_channels(self) -> bool:
5✔
828
        return (self.has_deterministic_node_id()
×
829
                and not self.config.LIGHTNING_LISTEN)
830

831
    def has_recoverable_channels(self) -> bool:
5✔
832
        """Whether *future* channels opened by this wallet would be recoverable
833
        from seed (via putting OP_RETURN outputs into funding txs).
834
        """
835
        return (self.can_have_recoverable_channels()
×
836
                and self.config.LIGHTNING_USE_RECOVERABLE_CHANNELS)
837

838
    @property
5✔
839
    def channels(self) -> Mapping[bytes, Channel]:
5✔
840
        """Returns a read-only copy of channels."""
841
        with self.lock:
5✔
842
            return self._channels.copy()
5✔
843

844
    @property
5✔
845
    def channel_backups(self) -> Mapping[bytes, ChannelBackup]:
5✔
846
        """Returns a read-only copy of channels."""
847
        with self.lock:
5✔
848
            return self._channel_backups.copy()
5✔
849

850
    def get_channel_objects(self) -> Mapping[bytes, AbstractChannel]:
5✔
851
        r = self.channel_backups
×
852
        r.update(self.channels)
×
853
        return r
×
854

855
    def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
5✔
856
        return self._channels.get(channel_id, None)
5✔
857

858
    def diagnostic_name(self):
5✔
859
        return self.wallet.diagnostic_name()
5✔
860

861
    @ignore_exceptions
5✔
862
    @log_exceptions
5✔
863
    async def sync_with_remote_watchtower(self):
5✔
864
        self.watchtower_ctns = {}
×
865
        while True:
×
866
            # periodically poll if the user updated 'watchtower_url'
867
            await asyncio.sleep(5)
×
868
            watchtower_url = self.config.WATCHTOWER_CLIENT_URL
×
869
            if not watchtower_url:
×
870
                continue
×
871
            parsed_url = urllib.parse.urlparse(watchtower_url)
×
872
            if not (parsed_url.scheme == 'https' or is_private_netaddress(parsed_url.hostname)):
×
873
                self.logger.warning(f"got watchtower URL for remote tower but we won't use it! "
×
874
                                    f"can only use HTTPS (except if private IP): not using {watchtower_url!r}")
875
                continue
×
876
            # try to sync with the remote watchtower
877
            try:
×
878
                async with make_aiohttp_session(proxy=self.network.proxy) as session:
×
879
                    watchtower = JsonRPCClient(session, watchtower_url)
×
880
                    watchtower.add_method('get_ctn')
×
881
                    watchtower.add_method('add_sweep_tx')
×
882
                    for chan in self.channels.values():
×
883
                        await self.sync_channel_with_watchtower(chan, watchtower)
×
884
            except aiohttp.client_exceptions.ClientConnectorError:
×
885
                self.logger.info(f'could not contact remote watchtower {watchtower_url}')
×
886

887
    def get_watchtower_ctn(self, channel_point):
5✔
888
        return self.watchtower_ctns.get(channel_point)
×
889

890
    async def sync_channel_with_watchtower(self, chan: Channel, watchtower):
5✔
891
        outpoint = chan.funding_outpoint.to_str()
×
892
        addr = chan.get_funding_address()
×
893
        current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
×
894
        watchtower_ctn = await watchtower.get_ctn(outpoint, addr)
×
895
        for ctn in range(watchtower_ctn + 1, current_ctn):
×
896
            sweeptxs = chan.create_sweeptxs_for_watchtower(ctn)
×
897
            for tx in sweeptxs:
×
898
                await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
×
899
            self.watchtower_ctns[outpoint] = ctn
×
900

901
    def start_network(self, network: 'Network'):
5✔
902
        super().start_network(network)
×
903
        self.lnwatcher = LNWalletWatcher(self, network)
×
904
        self.swap_manager.start_network(network)
×
905
        self.lnrater = LNRater(self, network)
×
906
        self.onion_message_manager.start_network(network=network)
×
907

908
        for chan in self.channels.values():
×
909
            if chan.need_to_subscribe():
×
910
                self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
×
911
        for cb in self.channel_backups.values():
×
912
            if cb.need_to_subscribe():
×
913
                self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
914

915
        for coro in [
×
916
                self.maybe_listen(),
917
                self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
918
                self.reestablish_peers_and_channels(),
919
                self.sync_with_remote_watchtower(),
920
        ]:
921
            tg_coro = self.taskgroup.spawn(coro)
×
922
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
923

924
    async def stop(self):
5✔
925
        self.stopping_soon = True
5✔
926
        if self.listen_server:  # stop accepting new peers
5✔
927
            self.listen_server.close()
×
928
        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
5✔
929
            await self.wait_for_received_pending_htlcs_to_get_removed()
5✔
930
        await LNWorker.stop(self)
5✔
931
        if self.lnwatcher:
5✔
932
            await self.lnwatcher.stop()
×
933
            self.lnwatcher = None
×
934
        if self.swap_manager and self.swap_manager.network:  # may not be present in tests
5✔
935
            await self.swap_manager.stop()
×
936
        if self.onion_message_manager:
5✔
937
            await self.onion_message_manager.stop()
×
938

939
    async def wait_for_received_pending_htlcs_to_get_removed(self):
5✔
940
        assert self.stopping_soon is True
5✔
941
        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
942
        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
943
        #       to wait a bit for it to become irrevocably removed.
944
        # Note: we don't wait for *all htlcs* to get removed, only for those
945
        #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
946
        async with OldTaskGroup() as group:
5✔
947
            for peer in self.peers.values():
5✔
948
                await group.spawn(peer.wait_one_htlc_switch_iteration())
5✔
949
        while True:
5✔
950
            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
5✔
951
                break
5✔
952
            async with OldTaskGroup(wait=any) as group:
4✔
953
                for peer in self.peers.values():
4✔
954
                    await group.spawn(peer.received_htlc_removed_event.wait())
4✔
955

956
    def peer_closed(self, peer):
5✔
957
        for chan in self.channels_for_peer(peer.pubkey).values():
×
958
            chan.peer_state = PeerState.DISCONNECTED
×
959
            util.trigger_callback('channel', self.wallet, chan)
×
960
        super().peer_closed(peer)
×
961

962
    def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
5✔
963
        out = defaultdict(list)
5✔
964
        for chan in self.channels.values():
5✔
965
            d = chan.get_payments(status=status)
5✔
966
            for payment_hash, plist in d.items():
5✔
967
                out[payment_hash] += plist
5✔
968
        return out
5✔
969

970
    def get_payment_value(
5✔
971
            self, info: Optional['PaymentInfo'],
972
            plist: List[HTLCWithStatus]) -> Tuple[PaymentDirection, int, Optional[int], int]:
973
        """ fee_msat is included in amount_msat"""
974
        assert plist
×
975
        amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist)
×
976
        if all(x.direction == SENT for x in plist):
×
977
            direction = PaymentDirection.SENT
×
978
            fee_msat = (- info.amount_msat - amount_msat) if info else None
×
979
        elif all(x.direction == RECEIVED for x in plist):
×
980
            direction = PaymentDirection.RECEIVED
×
981
            fee_msat = None
×
982
        elif amount_msat < 0:
×
983
            direction = PaymentDirection.SELF_PAYMENT
×
984
            fee_msat = - amount_msat
×
985
        else:
986
            direction = PaymentDirection.FORWARDING
×
987
            fee_msat = - amount_msat
×
988
        timestamp = min([htlc_with_status.htlc.timestamp for htlc_with_status in plist])
×
989
        return direction, amount_msat, fee_msat, timestamp
×
990

991
    def get_lightning_history(self) -> Dict[str, LightningHistoryItem]:
5✔
992
        """
993
        side effect: sets defaults labels
994
        note that the result is not ordered
995
        """
996
        out = {}
×
997
        for payment_hash, plist in self.get_payments(status='settled').items():
×
998
            if len(plist) == 0:
×
999
                continue
×
1000
            key = payment_hash.hex()
×
1001
            info = self.get_payment_info(payment_hash)
×
1002
            direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
×
1003
            label = self.wallet.get_label_for_rhash(key)
×
1004
            if not label and direction == PaymentDirection.FORWARDING:
×
1005
                label = _('Forwarding')
×
1006
            preimage = self.get_preimage(payment_hash).hex()
×
1007
            group_id = self.swap_manager.get_group_id_for_payment_hash(payment_hash)
×
1008
            item = LightningHistoryItem(
×
1009
                type = 'payment',
1010
                payment_hash = payment_hash.hex(),
1011
                preimage = preimage,
1012
                amount_msat = amount_msat,
1013
                fee_msat = fee_msat,
1014
                group_id = group_id,
1015
                timestamp = timestamp or 0,
1016
                label=label,
1017
            )
1018
            out[payment_hash.hex()] = item
×
1019
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
×
1020
            item = chan.get_funding_height()
×
1021
            if item is None:
×
1022
                continue
×
1023
            funding_txid, funding_height, funding_timestamp = item
×
1024
            label = _('Open channel') + ' ' + chan.get_id_for_log()
×
1025
            self.wallet.set_default_label(funding_txid, label)
×
1026
            self.wallet.set_group_label(funding_txid, label)
×
1027
            item = LightningHistoryItem(
×
1028
                type = 'channel_opening',
1029
                label = label,
1030
                group_id = funding_txid,
1031
                timestamp = funding_timestamp,
1032
                amount_msat = chan.balance(LOCAL, ctn=0),
1033
                fee_msat = None,
1034
                payment_hash = None,
1035
                preimage = None,
1036
            )
1037
            out[funding_txid] = item
×
1038
            item = chan.get_closing_height()
×
1039
            if item is None:
×
1040
                continue
×
1041
            closing_txid, closing_height, closing_timestamp = item
×
1042
            label = _('Close channel') + ' ' + chan.get_id_for_log()
×
1043
            self.wallet.set_default_label(closing_txid, label)
×
1044
            self.wallet.set_group_label(closing_txid, label)
×
1045
            item = LightningHistoryItem(
×
1046
                type = 'channel_closing',
1047
                label = label,
1048
                group_id = closing_txid,
1049
                timestamp = closing_timestamp,
1050
                amount_msat = -chan.balance(LOCAL),
1051
                fee_msat = None,
1052
                payment_hash = None,
1053
                preimage = None,
1054
            )
1055
            out[closing_txid] = item
×
1056

1057
        # sanity check
1058
        balance_msat = sum([x.amount_msat for x in out.values()])
×
1059
        lb = sum(chan.balance(LOCAL) if not chan.is_closed_or_closing() else 0
×
1060
                for chan in self.channels.values())
1061
        assert balance_msat  == lb
×
1062
        return out
×
1063

1064
    def get_groups_for_onchain_history(self) -> Dict[str, str]:
5✔
1065
        """
1066
        returns dict: txid -> group_id
1067
        side effect: sets default labels
1068
        """
1069
        groups = {}
×
1070
        # add funding events
1071
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
×
1072
            item = chan.get_funding_height()
×
1073
            if item is None:
×
1074
                continue
×
1075
            funding_txid, funding_height, funding_timestamp = item
×
1076
            groups[funding_txid] = funding_txid
×
1077
            item = chan.get_closing_height()
×
1078
            if item is None:
×
1079
                continue
×
1080
            closing_txid, closing_height, closing_timestamp = item
×
1081
            groups[closing_txid] = closing_txid
×
1082

1083
        d = self.swap_manager.get_groups_for_onchain_history()
×
1084
        for txid, v in d.items():
×
1085
            group_id = v['group_id']
×
1086
            label = v.get('label')
×
1087
            group_label = v.get('group_label') or label
×
1088
            groups[txid] = group_id
×
1089
            if label:
×
1090
                self.wallet.set_default_label(txid, label)
×
1091
            if group_label:
×
1092
                self.wallet.set_group_label(group_id, group_label)
×
1093

1094
        return groups
×
1095

1096
    def channel_peers(self) -> List[bytes]:
5✔
1097
        node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
×
1098
        return node_ids
×
1099

1100
    def channels_for_peer(self, node_id):
5✔
1101
        assert type(node_id) is bytes
5✔
1102
        return {chan_id: chan for (chan_id, chan) in self.channels.items()
5✔
1103
                if chan.node_id == node_id}
1104

1105
    def channel_state_changed(self, chan: Channel):
5✔
1106
        if type(chan) is Channel:
×
1107
            self.save_channel(chan)
×
1108
        self.clear_invoices_cache()
×
1109
        util.trigger_callback('channel', self.wallet, chan)
×
1110

1111
    def save_channel(self, chan: Channel):
5✔
1112
        assert type(chan) is Channel
×
1113
        if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
×
1114
            raise Exception("Tried to save channel with next_point == current_point, this should not happen")
×
1115
        self.wallet.save_db()
×
1116
        util.trigger_callback('channel', self.wallet, chan)
×
1117

1118
    def channel_by_txo(self, txo: str) -> Optional[AbstractChannel]:
5✔
1119
        for chan in self.channels.values():
×
1120
            if chan.funding_outpoint.to_str() == txo:
×
1121
                return chan
×
1122
        for chan in self.channel_backups.values():
×
1123
            if chan.funding_outpoint.to_str() == txo:
×
1124
                return chan
×
1125

1126
    async def handle_onchain_state(self, chan: Channel):
5✔
1127
        if type(chan) is ChannelBackup:
×
1128
            util.trigger_callback('channel', self.wallet, chan)
×
1129
            return
×
1130

1131
        if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
×
1132
                and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height())):
1133
            self.logger.info(f"force-closing due to expiring htlcs")
×
1134
            await self.schedule_force_closing(chan.channel_id)
×
1135

1136
        elif chan.get_state() == ChannelState.FUNDED:
×
1137
            peer = self._peers.get(chan.node_id)
×
1138
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
×
1139
                peer.send_channel_ready(chan)
×
1140

1141
        elif chan.get_state() == ChannelState.OPEN:
×
1142
            peer = self._peers.get(chan.node_id)
×
1143
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
×
1144
                peer.maybe_update_fee(chan)
×
1145
                peer.maybe_send_announcement_signatures(chan)
×
1146

1147
        elif chan.get_state() == ChannelState.FORCE_CLOSING:
×
1148
            force_close_tx = chan.force_close_tx()
×
1149
            txid = force_close_tx.txid()
×
1150
            height = self.lnwatcher.adb.get_tx_height(txid).height
×
1151
            if height == TX_HEIGHT_LOCAL:
×
1152
                self.logger.info('REBROADCASTING CLOSING TX')
×
1153
                await self.network.try_broadcasting(force_close_tx, 'force-close')
×
1154

1155
    def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
5✔
1156
        for nodeid, peer in self.peers.items():
×
1157
            if scid_alias == self._scid_alias_of_node(nodeid):
×
1158
                return peer
×
1159

1160
    def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
5✔
1161
        # scid alias for just-in-time channels
1162
        return sha256(b'Electrum' + nodeid)[0:8]
×
1163

1164
    def get_static_jit_scid_alias(self) -> bytes:
5✔
1165
        return self._scid_alias_of_node(self.node_keypair.pubkey)
×
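    # Illustrative sketch (not part of lnworker.py): the static JIT alias above is
    # just the first 8 bytes of sha256(b'Electrum' || node_id), i.e. it is
    # deterministic per node.  A standalone approximation using only the standard
    # library:
    def _example_jit_scid_alias(node_id: bytes) -> bytes:
        import hashlib
        return hashlib.sha256(b'Electrum' + node_id).digest()[0:8]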
1166

1167
    @log_exceptions
5✔
1168
    async def open_channel_just_in_time(
5✔
1169
        self,
1170
        *,
1171
        next_peer: Peer,
1172
        next_amount_msat_htlc: int,
1173
        next_cltv_abs: int,
1174
        payment_hash: bytes,
1175
        next_onion: OnionPacket,
1176
    ) -> str:
1177
        # if an exception is raised during negotiation, we raise an OnionRoutingFailure.
1178
        # this will cancel the incoming HTLC
1179
        try:
×
1180
            funding_sat = 2 * (next_amount_msat_htlc // 1000) # try to fully spend htlcs
×
1181
            password = self.wallet.get_unlocked_password() if self.wallet.has_password() else None
×
1182
            channel_opening_fee = next_amount_msat_htlc // 100
×
1183
            if channel_opening_fee // 1000 < self.config.ZEROCONF_MIN_OPENING_FEE:
×
1184
                self.logger.info(f'rejecting JIT channel: payment too low')
×
1185
                raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'payment too low')
×
1186
            self.logger.info(f'channel opening fee (sats): {channel_opening_fee//1000}')
×
1187
            next_chan, funding_tx = await self.open_channel_with_peer(
×
1188
                next_peer, funding_sat,
1189
                push_sat=0,
1190
                zeroconf=True,
1191
                public=False,
1192
                opening_fee=channel_opening_fee,
1193
                password=password,
1194
            )
1195
            async def wait_for_channel():
×
1196
                while not next_chan.is_open():
×
1197
                    await asyncio.sleep(1)
×
1198
            await util.wait_for2(wait_for_channel(), LN_P2P_NETWORK_TIMEOUT)
×
1199
            next_chan.save_remote_scid_alias(self._scid_alias_of_node(next_peer.pubkey))
×
1200
            self.logger.info(f'JIT channel is open')
×
1201
            next_amount_msat_htlc -= channel_opening_fee
×
1202
            # fixme: some checks are missing
1203
            htlc = next_peer.send_htlc(
×
1204
                chan=next_chan,
1205
                payment_hash=payment_hash,
1206
                amount_msat=next_amount_msat_htlc,
1207
                cltv_abs=next_cltv_abs,
1208
                onion=next_onion)
1209
            async def wait_for_preimage():
×
1210
                while self.get_preimage(payment_hash) is None:
×
1211
                    await asyncio.sleep(1)
×
1212
            await util.wait_for2(wait_for_preimage(), LN_P2P_NETWORK_TIMEOUT)
×
1213
        except OnionRoutingFailure:
×
1214
            raise
×
1215
        except Exception:
×
1216
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
×
1217
        # We have been paid and can broadcast
1218
        # todo: if broadcasting raises an exception, we should try to rebroadcast
1219
        await self.network.broadcast_transaction(funding_tx)
×
1220
        htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id)
×
1221
        return htlc_key
×
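    # Illustrative sketch (not part of lnworker.py): the JIT flow above charges 1%
    # of the incoming HTLC as channel_opening_fee (in msat), rejects the payment if
    # that fee, expressed in sats, is below ZEROCONF_MIN_OPENING_FEE, and sizes the
    # new channel at twice the HTLC amount.  Roughly:
    def _example_jit_terms(next_amount_msat_htlc: int, min_opening_fee_sat: int):
        funding_sat = 2 * (next_amount_msat_htlc // 1000)
        channel_opening_fee_msat = next_amount_msat_htlc // 100
        accept = channel_opening_fee_msat // 1000 >= min_opening_fee_sat
        return funding_sat, channel_opening_fee_msat, accept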
1222

1223
    @log_exceptions
5✔
1224
    async def open_channel_with_peer(
5✔
1225
            self, peer, funding_sat, *,
1226
            push_sat: int = 0,
1227
            public: bool = False,
1228
            zeroconf: bool = False,
1229
            opening_fee: int = None,
1230
            password=None):
1231
        if self.config.ENABLE_ANCHOR_CHANNELS:
×
1232
            self.wallet.unlock(password)
×
1233
        coins = self.wallet.get_spendable_coins(None)
×
1234
        node_id = peer.pubkey
×
1235
        funding_tx = self.mktx_for_open_channel(
×
1236
            coins=coins,
1237
            funding_sat=funding_sat,
1238
            node_id=node_id,
1239
            fee_est=None)
1240
        chan, funding_tx = await self._open_channel_coroutine(
×
1241
            peer=peer,
1242
            funding_tx=funding_tx,
1243
            funding_sat=funding_sat,
1244
            push_sat=push_sat,
1245
            public=public,
1246
            zeroconf=zeroconf,
1247
            opening_fee=opening_fee,
1248
            password=password)
1249
        return chan, funding_tx
×
1250

1251
    @log_exceptions
5✔
1252
    async def _open_channel_coroutine(
5✔
1253
            self, *,
1254
            peer: Peer,
1255
            funding_tx: PartialTransaction,
1256
            funding_sat: int,
1257
            push_sat: int,
1258
            public: bool,
1259
            zeroconf=False,
1260
            opening_fee=None,
1261
            password: Optional[str],
1262
    ) -> Tuple[Channel, PartialTransaction]:
1263

1264
        if funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1265
            raise Exception(
×
1266
                _("Requested channel capacity is over maximum.")
1267
                + f"\n{funding_sat} sat > {self.config.LIGHTNING_MAX_FUNDING_SAT} sat"
1268
            )
1269
        coro = peer.channel_establishment_flow(
×
1270
            funding_tx=funding_tx,
1271
            funding_sat=funding_sat,
1272
            push_msat=push_sat * 1000,
1273
            public=public,
1274
            zeroconf=zeroconf,
1275
            opening_fee=opening_fee,
1276
            temp_channel_id=os.urandom(32))
1277
        chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
×
1278
        util.trigger_callback('channels_updated', self.wallet)
×
1279
        self.wallet.adb.add_transaction(funding_tx)  # save tx as local into the wallet
×
1280
        self.wallet.sign_transaction(funding_tx, password)
×
1281
        if funding_tx.is_complete() and not zeroconf:
×
1282
            await self.network.try_broadcasting(funding_tx, 'open_channel')
×
1283
        return chan, funding_tx
×
1284

1285
    def add_channel(self, chan: Channel):
5✔
1286
        with self.lock:
×
1287
            self._channels[chan.channel_id] = chan
×
1288
        self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
×
1289

1290
    def add_new_channel(self, chan: Channel):
5✔
1291
        self.add_channel(chan)
×
1292
        channels_db = self.db.get_dict('channels')
×
1293
        channels_db[chan.channel_id.hex()] = chan.storage
×
1294
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
×
1295
        try:
×
1296
            self.save_channel(chan)
×
1297
        except Exception:
×
1298
            chan.set_state(ChannelState.REDEEMED)
×
1299
            self.remove_channel(chan.channel_id)
×
1300
            raise
×
1301

1302
    def cb_data(self, node_id: bytes) -> bytes:
5✔
1303
        return CB_MAGIC_BYTES + node_id[0:NODE_ID_PREFIX_LEN]
×
1304

1305
    def decrypt_cb_data(self, encrypted_data, funding_address):
5✔
1306
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
×
1307
        nonce = funding_scripthash[0:12]
×
1308
        return chacha20_decrypt(key=self.backup_key, data=encrypted_data, nonce=nonce)
×
1309

1310
    def encrypt_cb_data(self, data, funding_address):
5✔
1311
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
×
1312
        nonce = funding_scripthash[0:12]
×
1313
        # note: we are only using chacha20 instead of chacha20+poly1305 to save onchain space
1314
        #       (i.e. to avoid the 16 byte MAC). Otherwise, the latter would be preferable.
1315
        return chacha20_encrypt(key=self.backup_key, data=data, nonce=nonce)
×
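    # Illustrative sketch (not part of lnworker.py): the nonce is the first 12 bytes
    # of the funding address' script hash, so decryption only needs the backup key
    # and the funding address.  A round trip with the module's own helpers would
    # look roughly like this:
    def _example_cb_roundtrip(backup_key: bytes, data: bytes, funding_address: str) -> bool:
        nonce = bytes.fromhex(address_to_scripthash(funding_address))[0:12]
        blob = chacha20_encrypt(key=backup_key, data=data, nonce=nonce)
        return chacha20_decrypt(key=backup_key, data=blob, nonce=nonce) == data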
1316

1317
    def mktx_for_open_channel(
5✔
1318
            self, *,
1319
            coins: Sequence[PartialTxInput],
1320
            funding_sat: int,
1321
            node_id: bytes,
1322
            fee_est=None) -> PartialTransaction:
1323
        from .wallet import get_locktime_for_new_transaction
×
1324

1325
        outputs = [PartialTxOutput.from_address_and_value(DummyAddress.CHANNEL, funding_sat)]
×
1326
        if self.has_recoverable_channels():
×
1327
            dummy_scriptpubkey = make_op_return(self.cb_data(node_id))
×
1328
            outputs.append(PartialTxOutput(scriptpubkey=dummy_scriptpubkey, value=0))
×
1329
        tx = self.wallet.make_unsigned_transaction(
×
1330
            coins=coins,
1331
            outputs=outputs,
1332
            fee=fee_est)
1333
        tx.set_rbf(False)
×
1334
        # rm randomness from locktime, as we use the locktime as entropy for deriving the funding_privkey
1335
        # (and it would be confusing to get a collision as a consequence of the randomness)
1336
        tx.locktime = get_locktime_for_new_transaction(self.network, include_random_component=False)
×
1337
        return tx
×
1338

1339
    def suggest_funding_amount(self, amount_to_pay, coins):
5✔
1340
        """ whether we can pay amount_sat after opening a new channel"""
1341
        num_sats_can_send = int(self.num_sats_can_send())
×
1342
        lightning_needed = amount_to_pay - num_sats_can_send
×
1343
        assert lightning_needed > 0
×
1344
        min_funding_sat = lightning_needed + (lightning_needed // 20) + 1000 # safety margin
×
1345
        min_funding_sat = max(min_funding_sat, 100_000) # at least 1mBTC
×
1346
        if min_funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1347
            return
×
1348
        fee_est = partial(self.config.estimate_fee, allow_fallback_to_static_rates=True)  # to avoid NoDynamicFeeEstimates
×
1349
        try:
×
1350
            self.mktx_for_open_channel(coins=coins, funding_sat=min_funding_sat, node_id=bytes(32), fee_est=fee_est)
×
1351
            funding_sat = min_funding_sat
×
1352
        except NotEnoughFunds:
×
1353
            return
×
1354
        # if available, suggest twice that amount:
1355
        if 2 * min_funding_sat <= self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1356
            try:
×
1357
                self.mktx_for_open_channel(coins=coins, funding_sat=2*min_funding_sat, node_id=bytes(32), fee_est=fee_est)
×
1358
                funding_sat = 2 * min_funding_sat
×
1359
            except NotEnoughFunds:
×
1360
                pass
×
1361
        return funding_sat, min_funding_sat
×
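    # Illustrative sketch (not part of lnworker.py): the suggestion above adds a
    # ~5% + 1000 sat safety margin on top of the missing lightning capacity and
    # never proposes less than 100_000 sat.  E.g. for amount_to_pay_sat=150_000 and
    # can_send_sat=50_000 this yields 106_000 sat:
    def _example_min_funding(amount_to_pay_sat: int, can_send_sat: int) -> int:
        lightning_needed = amount_to_pay_sat - can_send_sat
        min_funding = lightning_needed + lightning_needed // 20 + 1000  # margin
        return max(min_funding, 100_000)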
1362

1363
    def open_channel(
5✔
1364
            self, *,
1365
            connect_str: str,
1366
            funding_tx: PartialTransaction,
1367
            funding_sat: int,
1368
            push_amt_sat: int,
1369
            public: bool = False,
1370
            password: str = None,
1371
    ) -> Tuple[Channel, PartialTransaction]:
1372

1373
        fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
×
1374
        try:
×
1375
            peer = fut.result()
×
1376
        except concurrent.futures.TimeoutError:
×
1377
            raise Exception(_("add peer timed out"))
×
1378
        coro = self._open_channel_coroutine(
×
1379
            peer=peer,
1380
            funding_tx=funding_tx,
1381
            funding_sat=funding_sat,
1382
            push_sat=push_amt_sat,
1383
            public=public,
1384
            password=password)
1385
        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
×
1386
        try:
×
1387
            chan, funding_tx = fut.result()
×
1388
        except concurrent.futures.TimeoutError:
×
1389
            raise Exception(_("open_channel timed out"))
×
1390
        return chan, funding_tx
×
1391

1392
    def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
5✔
1393
        # First check against *real* SCIDs.
1394
        # This e.g. protects against maliciously chosen SCID aliases, and accidental collisions.
1395
        for chan in self.channels.values():
×
1396
            if chan.short_channel_id == short_channel_id:
×
1397
                return chan
×
1398
        # Now we also consider aliases.
1399
        # TODO we should split this as this search currently ignores the "direction"
1400
        #      of the aliases. We should only look at either the remote OR the local alias,
1401
        #      depending on context.
1402
        for chan in self.channels.values():
×
1403
            if chan.get_remote_scid_alias() == short_channel_id:
×
1404
                return chan
×
1405
            if chan.get_local_scid_alias() == short_channel_id:
×
1406
                return chan
×
1407

1408
    def can_pay_invoice(self, invoice: Invoice) -> bool:
5✔
1409
        assert invoice.is_lightning()
×
1410
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_send()
×
1411

1412
    @log_exceptions
5✔
1413
    async def pay_invoice(
5✔
1414
            self, invoice: str, *,
1415
            amount_msat: int = None,
1416
            attempts: int = None,  # used only in unit tests
1417
            full_path: LNPaymentPath = None,
1418
            channels: Optional[Sequence[Channel]] = None,
1419
    ) -> Tuple[bool, List[HtlcLog]]:
1420

1421
        lnaddr = self._check_invoice(invoice, amount_msat=amount_msat)
5✔
1422
        min_final_cltv_delta = lnaddr.get_min_final_cltv_delta()
5✔
1423
        payment_hash = lnaddr.paymenthash
5✔
1424
        key = payment_hash.hex()
5✔
1425
        payment_secret = lnaddr.payment_secret
5✔
1426
        invoice_pubkey = lnaddr.pubkey.serialize()
5✔
1427
        invoice_features = lnaddr.get_features()
5✔
1428
        r_tags = lnaddr.get_routing_info('r')
5✔
1429
        amount_to_pay = lnaddr.get_amount_msat()
5✔
1430
        status = self.get_payment_status(payment_hash)
5✔
1431
        if status == PR_PAID:
5✔
1432
            raise PaymentFailure(_("This invoice has been paid already"))
×
1433
        if status == PR_INFLIGHT:
5✔
1434
            raise PaymentFailure(_("A payment was already initiated for this invoice"))
×
1435
        if payment_hash in self.get_payments(status='inflight'):
5✔
1436
            raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
×
1437
        info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
5✔
1438
        self.save_payment_info(info)
5✔
1439
        self.wallet.set_label(key, lnaddr.get_description())
5✔
1440
        self.set_invoice_status(key, PR_INFLIGHT)
5✔
1441
        budget = PaymentFeeBudget.default(invoice_amount_msat=amount_to_pay, config=self.config)
5✔
1442
        if attempts is None and self.uses_trampoline():
5✔
1443
            # we don't expect lots of failed htlcs with trampoline, so we can fail sooner
1444
            attempts = 30
5✔
1445
        success = False
5✔
1446
        try:
5✔
1447
            await self.pay_to_node(
5✔
1448
                node_pubkey=invoice_pubkey,
1449
                payment_hash=payment_hash,
1450
                payment_secret=payment_secret,
1451
                amount_to_pay=amount_to_pay,
1452
                min_final_cltv_delta=min_final_cltv_delta,
1453
                r_tags=r_tags,
1454
                invoice_features=invoice_features,
1455
                attempts=attempts,
1456
                full_path=full_path,
1457
                channels=channels,
1458
                budget=budget,
1459
            )
1460
            success = True
5✔
1461
        except PaymentFailure as e:
5✔
1462
            self.logger.info(f'payment failure: {e!r}')
5✔
1463
            reason = str(e)
5✔
1464
        except ChannelDBNotLoaded as e:
5✔
1465
            self.logger.info(f'payment failure: {e!r}')
×
1466
            reason = str(e)
×
1467
        finally:
1468
            self.logger.info(f"pay_invoice ending session for RHASH={payment_hash.hex()}. {success=}")
5✔
1469
        if success:
5✔
1470
            self.set_invoice_status(key, PR_PAID)
5✔
1471
            util.trigger_callback('payment_succeeded', self.wallet, key)
5✔
1472
        else:
1473
            self.set_invoice_status(key, PR_UNPAID)
5✔
1474
            util.trigger_callback('payment_failed', self.wallet, key, reason)
5✔
1475
        log = self.logs[key]
5✔
1476
        return success, log
5✔
1477

1478
    async def pay_to_node(
5✔
1479
            self, *,
1480
            node_pubkey: bytes,
1481
            payment_hash: bytes,
1482
            payment_secret: bytes,
1483
            amount_to_pay: int,  # in msat
1484
            min_final_cltv_delta: int,
1485
            r_tags,
1486
            invoice_features: int,
1487
            attempts: int = None,
1488
            full_path: LNPaymentPath = None,
1489
            fwd_trampoline_onion: OnionPacket = None,
1490
            budget: PaymentFeeBudget,
1491
            channels: Optional[Sequence[Channel]] = None,
1492
            fw_payment_key: str = None,  # for forwarding
1493
    ) -> None:
1494

1495
        assert budget
5✔
1496
        assert budget.fee_msat >= 0, budget
5✔
1497
        assert budget.cltv >= 0, budget
5✔
1498

1499
        payment_key = payment_hash + payment_secret
5✔
1500
        assert payment_key not in self._paysessions
5✔
1501
        self._paysessions[payment_key] = paysession = PaySession(
5✔
1502
            payment_hash=payment_hash,
1503
            payment_secret=payment_secret,
1504
            initial_trampoline_fee_level=self.config.INITIAL_TRAMPOLINE_FEE_LEVEL,
1505
            invoice_features=invoice_features,
1506
            r_tags=r_tags,
1507
            min_final_cltv_delta=min_final_cltv_delta,
1508
            amount_to_pay=amount_to_pay,
1509
            invoice_pubkey=node_pubkey,
1510
            uses_trampoline=self.uses_trampoline(),
1511
            use_two_trampolines=self.config.LIGHTNING_LEGACY_ADD_TRAMPOLINE,
1512
        )
1513
        self.logs[payment_hash.hex()] = log = []  # TODO incl payment_secret in key (re trampoline forwarding)
5✔
1514

1515
        paysession.logger.info(
5✔
1516
            f"pay_to_node starting session for RHASH={payment_hash.hex()}. "
1517
            f"using_trampoline={self.uses_trampoline()}. "
1518
            f"invoice_features={paysession.invoice_features.get_names()}. "
1519
            f"{amount_to_pay=} msat. {budget=}")
1520
        if not self.uses_trampoline():
5✔
1521
            self.logger.info(
5✔
1522
                f"gossip_db status. sync progress: {self.network.lngossip.get_sync_progress_estimate()}. "
1523
                f"num_nodes={self.channel_db.num_nodes}, "
1524
                f"num_channels={self.channel_db.num_channels}, "
1525
                f"num_policies={self.channel_db.num_policies}.")
1526

1527
        # when encountering trampoline forwarding difficulties in the legacy case, we
1528
        # sometimes need to fall back to a single trampoline forwarder, at the expense
1529
        # of privacy
1530
        try:
5✔
1531
            while True:
5✔
1532
                if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
5✔
1533
                    # 1. create a set of routes for remaining amount.
1534
                    # note: path-finding runs in a separate thread so that we don't block the asyncio loop
1535
                    # graph updates might occur during the computation
1536
                    remaining_fee_budget_msat = (budget.fee_msat * amount_to_send) // amount_to_pay
5✔
1537
                    routes = self.create_routes_for_payment(
5✔
1538
                        paysession=paysession,
1539
                        amount_msat=amount_to_send,
1540
                        full_path=full_path,
1541
                        fwd_trampoline_onion=fwd_trampoline_onion,
1542
                        channels=channels,
1543
                        budget=budget._replace(fee_msat=remaining_fee_budget_msat),
1544
                    )
1545
                    # 2. send htlcs
1546
                    async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
5✔
1547
                        await self.pay_to_route(
5✔
1548
                            paysession=paysession,
1549
                            sent_htlc_info=sent_htlc_info,
1550
                            min_final_cltv_delta=cltv_delta,
1551
                            trampoline_onion=trampoline_onion,
1552
                            fw_payment_key=fw_payment_key,
1553
                        )
1554
                    # invoice_status is triggered in self.set_invoice_status when it actually changes.
1555
                    # It is also triggered here to update progress for a lightning payment in the GUI
1556
                    # (e.g. attempt counter)
1557
                    util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
5✔
1558
                # 3. await a queue
1559
                htlc_log = await paysession.wait_for_one_htlc_to_resolve()  # TODO maybe wait a bit, more failures might come
5✔
1560
                log.append(htlc_log)
5✔
1561
                if htlc_log.success:
5✔
1562
                    if self.network.path_finder:
5✔
1563
                        # TODO: report every route to liquidity hints for mpp
1564
                        # in the case of success, we report channels of the
1565
                        # route as being able to send the same amount in the future,
1566
                        # since we do not know the actual capacity
1567
                        self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat)
5✔
1568
                        # remove inflight htlcs from liquidity hints
1569
                        self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
5✔
1570
                    return
5✔
1571
                # htlc failed
1572
                # if we get a tmp channel failure, it might work to split the amount and try more routes
1573
                # if we get a channel update, we might retry the same route and amount
1574
                route = htlc_log.route
5✔
1575
                sender_idx = htlc_log.sender_idx
5✔
1576
                failure_msg = htlc_log.failure_msg
5✔
1577
                if sender_idx is None:
5✔
1578
                    raise PaymentFailure(failure_msg.code_name())
4✔
1579
                erring_node_id = route[sender_idx].node_id
5✔
1580
                code, data = failure_msg.code, failure_msg.data
5✔
1581
                self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. "
5✔
1582
                                 f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}")
1583
                self.logger.info(f"error reported by {erring_node_id.hex()}")
5✔
1584
                if code == OnionFailureCode.MPP_TIMEOUT:
5✔
1585
                    raise PaymentFailure(failure_msg.code_name())
5✔
1586
                # errors returned by the next trampoline.
1587
                if fwd_trampoline_onion and code in [
5✔
1588
                        OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
1589
                        OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON]:
1590
                    raise failure_msg
×
1591
                # trampoline
1592
                if self.uses_trampoline():
5✔
1593
                    paysession.handle_failed_trampoline_htlc(
5✔
1594
                        htlc_log=htlc_log, failure_msg=failure_msg)
1595
                else:
1596
                    self.handle_error_code_from_failed_htlc(
5✔
1597
                        route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
1598
                # max attempts or timeout
1599
                if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
5✔
1600
                    raise PaymentFailure('Giving up after %d attempts' % len(log))
5✔
1601
        finally:
1602
            paysession.is_active = False
5✔
1603
            if paysession.can_be_deleted():
5✔
1604
                self._paysessions.pop(payment_key)
5✔
1605
            paysession.logger.info(f"pay_to_node ending session for RHASH={payment_hash.hex()}")
5✔
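    # Illustrative sketch (not part of lnworker.py): on each retry round above, the
    # remaining fee budget is scaled down proportionally to the amount that still
    # has to be sent, so already-settled parts keep their share of the budget:
    def _example_remaining_fee_budget(fee_budget_msat: int, outstanding_msat: int, total_msat: int) -> int:
        # e.g. 1_000 msat budget, 50_000 of 200_000 msat outstanding -> 250 msat
        return (fee_budget_msat * outstanding_msat) // total_msat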
1606

1607
    async def pay_to_route(
5✔
1608
            self, *,
1609
            paysession: PaySession,
1610
            sent_htlc_info: SentHtlcInfo,
1611
            min_final_cltv_delta: int,
1612
            trampoline_onion: Optional[OnionPacket] = None,
1613
            fw_payment_key: str = None,
1614
    ) -> None:
1615
        """Sends a single HTLC."""
1616
        shi = sent_htlc_info
5✔
1617
        del sent_htlc_info  # just renamed
5✔
1618
        short_channel_id = shi.route[0].short_channel_id
5✔
1619
        chan = self.get_channel_by_short_id(short_channel_id)
5✔
1620
        assert chan, ShortChannelID(short_channel_id)
5✔
1621
        peer = self._peers.get(shi.route[0].node_id)
5✔
1622
        if not peer:
5✔
1623
            raise PaymentFailure('Dropped peer')
×
1624
        await peer.initialized
5✔
1625
        htlc = peer.pay(
5✔
1626
            route=shi.route,
1627
            chan=chan,
1628
            amount_msat=shi.amount_msat,
1629
            total_msat=shi.bucket_msat,
1630
            payment_hash=paysession.payment_hash,
1631
            min_final_cltv_delta=min_final_cltv_delta,
1632
            payment_secret=shi.payment_secret_bucket,
1633
            trampoline_onion=trampoline_onion)
1634

1635
        key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
5✔
1636
        self.sent_htlcs_info[key] = shi
5✔
1637
        paysession.add_new_htlc(shi)
5✔
1638
        if fw_payment_key:
5✔
1639
            htlc_key = serialize_htlc_key(short_channel_id, htlc.htlc_id)
5✔
1640
            self.logger.info(f'adding active forwarding {fw_payment_key}')
5✔
1641
            self.active_forwardings[fw_payment_key].append(htlc_key)
5✔
1642
        if self.network.path_finder:
5✔
1643
            # add inflight htlcs to liquidity hints
1644
            self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
5✔
1645
        util.trigger_callback('htlc_added', chan, htlc, SENT)
5✔
1646

1647
    def handle_error_code_from_failed_htlc(
5✔
1648
            self,
1649
            *,
1650
            route: LNPaymentRoute,
1651
            sender_idx: int,
1652
            failure_msg: OnionRoutingFailure,
1653
            amount: int) -> None:
1654

1655
        assert self.channel_db  # cannot be in trampoline mode
5✔
1656
        assert self.network.path_finder
5✔
1657

1658
        # remove inflight htlcs from liquidity hints
1659
        self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False)
5✔
1660

1661
        code, data = failure_msg.code, failure_msg.data
5✔
1662
        # TODO can we use lnmsg.OnionWireSerializer here?
1663
        # TODO update onion_wire.csv
1664
        # handle some specific error codes
1665
        failure_codes = {
5✔
1666
            OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0,
1667
            OnionFailureCode.AMOUNT_BELOW_MINIMUM: 8,
1668
            OnionFailureCode.FEE_INSUFFICIENT: 8,
1669
            OnionFailureCode.INCORRECT_CLTV_EXPIRY: 4,
1670
            OnionFailureCode.EXPIRY_TOO_SOON: 0,
1671
            OnionFailureCode.CHANNEL_DISABLED: 2,
1672
        }
1673
        try:
5✔
1674
            failing_channel = route[sender_idx + 1].short_channel_id
5✔
1675
        except IndexError:
5✔
1676
            raise PaymentFailure(f'payment destination reported error: {failure_msg.code_name()}') from None
5✔
1677

1678
        # TODO: handle unknown next peer?
1679
        # handle failure codes that include a channel update
1680
        if code in failure_codes:
5✔
1681
            offset = failure_codes[code]
5✔
1682
            channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big")
5✔
1683
            channel_update_as_received = data[offset+2: offset+2+channel_update_len]
5✔
1684
            payload = self._decode_channel_update_msg(channel_update_as_received)
5✔
1685
            if payload is None:
5✔
1686
                self.logger.info(f'could not decode channel_update for failed htlc: '
×
1687
                                 f'{channel_update_as_received.hex()}')
1688
                blacklist = True
×
1689
            elif payload.get('short_channel_id') != failing_channel:
5✔
1690
                self.logger.info(f'short_channel_id in channel_update does not match our route')
×
1691
                blacklist = True
×
1692
            else:
1693
                # apply the channel update or get blacklisted
1694
                blacklist, update = self._handle_chanupd_from_failed_htlc(
5✔
1695
                    payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg)
1696
                # we interpret a temporary channel failure as a liquidity issue
1697
                # in the channel and update our liquidity hints accordingly
1698
                if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE:
5✔
1699
                    self.network.path_finder.update_liquidity_hints(
5✔
1700
                        route,
1701
                        amount,
1702
                        failing_channel=ShortChannelID(failing_channel))
1703
                # if we can't decide on some action, we are stuck
1704
                if not (blacklist or update):
5✔
1705
                    raise PaymentFailure(failure_msg.code_name())
×
1706
        # for errors that do not include a channel update
1707
        else:
1708
            blacklist = True
5✔
1709
        if blacklist:
5✔
1710
            self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel)
5✔
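    # Illustrative sketch (not part of lnworker.py): for the codes listed in
    # failure_codes above, the onion error payload carries a channel_update at a
    # code-specific offset, encoded as a 2-byte big-endian length followed by the
    # update itself:
    def _example_extract_channel_update(data: bytes, offset: int) -> bytes:
        length = int.from_bytes(data[offset:offset + 2], byteorder="big")
        return data[offset + 2: offset + 2 + length]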
1711

1712
    def _handle_chanupd_from_failed_htlc(
5✔
1713
        self, payload, *,
1714
        route: LNPaymentRoute,
1715
        sender_idx: int,
1716
        failure_msg: OnionRoutingFailure,
1717
    ) -> Tuple[bool, bool]:
1718
        blacklist = False
5✔
1719
        update = False
5✔
1720
        try:
5✔
1721
            r = self.channel_db.add_channel_update(payload, verify=True)
5✔
1722
        except InvalidGossipMsg:
×
1723
            return True, False  # blacklist
×
1724
        short_channel_id = ShortChannelID(payload['short_channel_id'])
5✔
1725
        if r == UpdateStatus.GOOD:
5✔
1726
            self.logger.info(f"applied channel update to {short_channel_id}")
×
1727
            # TODO: add test for this
1728
            # FIXME: this does not work for our own unannounced channels.
1729
            for chan in self.channels.values():
×
1730
                if chan.short_channel_id == short_channel_id:
×
1731
                    chan.set_remote_update(payload)
×
1732
            update = True
×
1733
        elif r == UpdateStatus.ORPHANED:
5✔
1734
            # maybe it is a private channel (and data in invoice was outdated)
1735
            self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
5✔
1736
            start_node_id = route[sender_idx].node_id
5✔
1737
            cache_ttl = None
5✔
1738
            if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED:
5✔
1739
                # eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay
1740
                # a mobile phone with the app closed. So we cache this with a short TTL.
1741
                cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT
×
1742
            update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl)
5✔
1743
            blacklist = not update
5✔
1744
        elif r == UpdateStatus.EXPIRED:
×
1745
            blacklist = True
×
1746
        elif r == UpdateStatus.DEPRECATED:
×
1747
            self.logger.info('channel update is not more recent than the one we already have.')
×
1748
            blacklist = True
×
1749
        elif r == UpdateStatus.UNCHANGED:
×
1750
            blacklist = True
×
1751
        return blacklist, update
5✔
1752

1753
    @classmethod
5✔
1754
    def _decode_channel_update_msg(cls, chan_upd_msg: bytes) -> Optional[Dict[str, Any]]:
5✔
1755
        channel_update_as_received = chan_upd_msg
5✔
1756
        channel_update_typed = (258).to_bytes(length=2, byteorder="big") + channel_update_as_received
5✔
1757
        # note: some nodes put channel updates in error msgs with the leading msg_type already there.
1758
        #       we try decoding both ways here.
1759
        try:
5✔
1760
            message_type, payload = decode_msg(channel_update_typed)
5✔
1761
            if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
5✔
1762
            payload['raw'] = channel_update_typed
5✔
1763
            return payload
5✔
1764
        except Exception:  # FIXME: too broad
5✔
1765
            try:
5✔
1766
                message_type, payload = decode_msg(channel_update_as_received)
5✔
1767
                if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
5✔
1768
                payload['raw'] = channel_update_as_received
5✔
1769
                return payload
5✔
1770
            except Exception:
5✔
1771
                return None
5✔
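    # Illustrative sketch (not part of lnworker.py): 258 is the message type of
    # channel_update; prepending it as two big-endian bytes turns a raw update (as
    # some nodes embed it in error messages) into a message that decode_msg accepts:
    def _example_prefix_channel_update(chan_upd_without_type: bytes) -> bytes:
        return (258).to_bytes(length=2, byteorder="big") + chan_upd_without_type  # b'\x01\x02' + ...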
1772

1773
    def _check_invoice(self, invoice: str, *, amount_msat: int = None) -> LnAddr:
5✔
1774
        """Parses and validates a bolt11 invoice str into a LnAddr.
1775
        Includes pre-payment checks external to the parser.
1776
        """
1777
        addr = lndecode(invoice)
5✔
1778
        if addr.is_expired():
5✔
1779
            raise InvoiceError(_("This invoice has expired"))
×
1780
        # check amount
1781
        if amount_msat:  # replace amt in invoice. main use case is paying zero-amount invoices
5✔
1782
            existing_amt_msat = addr.get_amount_msat()
×
1783
            if existing_amt_msat and amount_msat < existing_amt_msat:
×
1784
                raise Exception("cannot pay lower amt than what is originally in LN invoice")
×
1785
            addr.amount = Decimal(amount_msat) / COIN / 1000
×
1786
        if addr.amount is None:
5✔
1787
            raise InvoiceError(_("Missing amount"))
×
1788
        # check cltv
1789
        if addr.get_min_final_cltv_delta() > NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
5✔
1790
            raise InvoiceError("{}\n{}".format(
5✔
1791
                _("Invoice wants us to risk locking funds for unreasonably long."),
1792
                f"min_final_cltv_delta: {addr.get_min_final_cltv_delta()}"))
1793
        # check features
1794
        addr.validate_and_compare_features(self.features)
5✔
1795
        return addr
5✔
1796

1797
    def is_trampoline_peer(self, node_id: bytes) -> bool:
5✔
1798
        # until trampoline is advertised in lnfeatures, check against hardcoded list
1799
        if is_hardcoded_trampoline(node_id):
5✔
1800
            return True
5✔
1801
        peer = self._peers.get(node_id)
×
1802
        if not peer:
×
1803
            return False
×
1804
        return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)\
×
1805
                or peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM))
1806

1807
    def suggest_peer(self) -> Optional[bytes]:
5✔
1808
        if not self.uses_trampoline():
×
1809
            return self.lnrater.suggest_peer()
×
1810
        else:
1811
            return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
×
1812

1813
    def suggest_splits(
5✔
1814
        self,
1815
        *,
1816
        amount_msat: int,
1817
        final_total_msat: int,
1818
        my_active_channels: Sequence[Channel],
1819
        invoice_features: LnFeatures,
1820
        r_tags,
1821
    ) -> List['SplitConfigRating']:
1822
        channels_with_funds = {
5✔
1823
            (chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL))
1824
            for chan in my_active_channels
1825
        }
1826
        self.logger.info(f"channels_with_funds: {channels_with_funds}")
5✔
1827
        exclude_single_part_payments = False
5✔
1828
        if self.uses_trampoline():
5✔
1829
            # in the case of a legacy payment, we don't allow splitting via different
1830
            # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
1831
            is_legacy, _ = is_legacy_relay(invoice_features, r_tags)
5✔
1832
            exclude_multinode_payments = is_legacy
5✔
1833
            # we don't split within a channel when sending to a trampoline node,
1834
            # the trampoline node will split for us
1835
            exclude_single_channel_splits = True
5✔
1836
        else:
1837
            exclude_multinode_payments = False
5✔
1838
            exclude_single_channel_splits = False
5✔
1839
            if invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and not self.config.TEST_FORCE_DISABLE_MPP:
5✔
1840
                # if amt is still large compared to total_msat, split it:
1841
                if (amount_msat / final_total_msat > self.MPP_SPLIT_PART_FRACTION
5✔
1842
                        and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT):
1843
                    exclude_single_part_payments = True
×
1844

1845
        def get_splits():
5✔
1846
            return suggest_splits(
5✔
1847
                amount_msat,
1848
                channels_with_funds,
1849
                exclude_single_part_payments=exclude_single_part_payments,
1850
                exclude_multinode_payments=exclude_multinode_payments,
1851
                exclude_single_channel_splits=exclude_single_channel_splits
1852
            )
1853

1854
        split_configurations = get_splits()
5✔
1855
        if not split_configurations and exclude_single_part_payments:
5✔
1856
            exclude_single_part_payments = False
×
1857
            split_configurations = get_splits()
×
1858
        self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
5✔
1859
        return split_configurations
5✔
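    # Illustrative sketch (not part of lnworker.py): in the non-trampoline branch
    # above, single-part payments are only excluded when the part is still a large
    # fraction of the total and above a minimum size; the concrete thresholds are
    # MPP_SPLIT_PART_FRACTION and MPP_SPLIT_PART_MINAMT_MSAT:
    def _example_should_force_split(amount_msat: int, total_msat: int,
                                    part_fraction: float, min_part_msat: int) -> bool:
        return amount_msat / total_msat > part_fraction and amount_msat > min_part_msat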
1860

1861
    async def create_routes_for_payment(
5✔
1862
            self, *,
1863
            paysession: PaySession,
1864
            amount_msat: int,        # part of payment amount we want routes for now
1865
            fwd_trampoline_onion: OnionPacket = None,
1866
            full_path: LNPaymentPath = None,
1867
            channels: Optional[Sequence[Channel]] = None,
1868
            budget: PaymentFeeBudget,
1869
    ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
1870

1871
        """Creates multiple routes for splitting a payment over the available
1872
        private channels.
1873

1874
        We first try to conduct the payment over a single channel. If that fails
1875
        and mpp is supported by the receiver, we will split the payment."""
1876
        trampoline_features = LnFeatures.VAR_ONION_OPT
5✔
1877
        local_height = self.network.get_local_height()
5✔
1878
        fee_related_error = None  # type: Optional[FeeBudgetExceeded]
5✔
1879
        if channels:
5✔
1880
            my_active_channels = channels
×
1881
        else:
1882
            my_active_channels = [
5✔
1883
                chan for chan in self.channels.values() if
1884
                chan.is_active() and not chan.is_frozen_for_sending()]
1885
        # try random order
1886
        random.shuffle(my_active_channels)
5✔
1887
        split_configurations = self.suggest_splits(
5✔
1888
            amount_msat=amount_msat,
1889
            final_total_msat=paysession.amount_to_pay,
1890
            my_active_channels=my_active_channels,
1891
            invoice_features=paysession.invoice_features,
1892
            r_tags=paysession.r_tags,
1893
        )
1894
        for sc in split_configurations:
5✔
1895
            is_multichan_mpp = len(sc.config.items()) > 1
5✔
1896
            is_mpp = sc.config.number_parts() > 1
5✔
1897
            if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
5✔
1898
                continue
5✔
1899
            if not is_mpp and self.config.TEST_FORCE_MPP:
5✔
1900
                continue
5✔
1901
            if is_mpp and self.config.TEST_FORCE_DISABLE_MPP:
5✔
1902
                continue
×
1903
            self.logger.info(f"trying split configuration: {sc.config.values()} rating: {sc.rating}")
5✔
1904
            routes = []
5✔
1905
            try:
5✔
1906
                if self.uses_trampoline():
5✔
1907
                    per_trampoline_channel_amounts = defaultdict(list)
5✔
1908
                    # categorize by trampoline nodes for trampoline mpp construction
1909
                    for (chan_id, _), part_amounts_msat in sc.config.items():
5✔
1910
                        chan = self.channels[chan_id]
5✔
1911
                        for part_amount_msat in part_amounts_msat:
5✔
1912
                            per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
5✔
1913
                    # for each trampoline forwarder, construct mpp trampoline
1914
                    for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
5✔
1915
                        per_trampoline_amount = sum([x[1] for x in trampoline_parts])
5✔
1916
                        trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
5✔
1917
                            amount_msat=per_trampoline_amount,
1918
                            total_msat=paysession.amount_to_pay,
1919
                            min_final_cltv_delta=paysession.min_final_cltv_delta,
1920
                            my_pubkey=self.node_keypair.pubkey,
1921
                            invoice_pubkey=paysession.invoice_pubkey,
1922
                            invoice_features=paysession.invoice_features,
1923
                            node_id=trampoline_node_id,
1924
                            r_tags=paysession.r_tags,
1925
                            payment_hash=paysession.payment_hash,
1926
                            payment_secret=paysession.payment_secret,
1927
                            local_height=local_height,
1928
                            trampoline_fee_level=paysession.trampoline_fee_level,
1929
                            use_two_trampolines=paysession.use_two_trampolines,
1930
                            failed_routes=paysession.failed_trampoline_routes,
1931
                            budget=budget._replace(fee_msat=budget.fee_msat // len(per_trampoline_channel_amounts)),
1932
                        )
1933
                        # node_features is only used to determine is_tlv
1934
                        per_trampoline_secret = os.urandom(32)
5✔
1935
                        per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
5✔
1936
                        self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
5✔
1937
                        self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
5✔
1938
                        self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
5✔
1939
                        for chan_id, part_amount_msat in trampoline_parts:
5✔
1940
                            chan = self.channels[chan_id]
5✔
1941
                            margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat
5✔
1942
                            delta_fee = min(per_trampoline_fees, margin)
5✔
1943
                            # TODO: distribute trampoline fee over several channels?
1944
                            part_amount_msat_with_fees = part_amount_msat + delta_fee
5✔
1945
                            per_trampoline_fees -= delta_fee
5✔
1946
                            route = [
5✔
1947
                                RouteEdge(
1948
                                    start_node=self.node_keypair.pubkey,
1949
                                    end_node=trampoline_node_id,
1950
                                    short_channel_id=chan.short_channel_id,
1951
                                    fee_base_msat=0,
1952
                                    fee_proportional_millionths=0,
1953
                                    cltv_delta=0,
1954
                                    node_features=trampoline_features)
1955
                            ]
1956
                            self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
5✔
1957
                            shi = SentHtlcInfo(
5✔
1958
                                route=route,
1959
                                payment_secret_orig=paysession.payment_secret,
1960
                                payment_secret_bucket=per_trampoline_secret,
1961
                                amount_msat=part_amount_msat_with_fees,
1962
                                bucket_msat=per_trampoline_amount_with_fees,
1963
                                amount_receiver_msat=part_amount_msat,
1964
                                trampoline_fee_level=paysession.trampoline_fee_level,
1965
                                trampoline_route=trampoline_route,
1966
                            )
1967
                            routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
5✔
1968
                        if per_trampoline_fees != 0:
5✔
1969
                            e = 'not enough margin to pay trampoline fee'
×
1970
                            self.logger.info(e)
×
1971
                            raise FeeBudgetExceeded(e)
×
1972
                else:
1973
                    # We atomically loop through a split configuration. If there was
1974
                    # a failure to find a path for a single part, we try the next configuration
1975
                    for (chan_id, _), part_amounts_msat in sc.config.items():
5✔
1976
                        for part_amount_msat in part_amounts_msat:
5✔
1977
                            channel = self.channels[chan_id]
5✔
1978
                            route = await run_in_thread(
5✔
1979
                                partial(
1980
                                    self.create_route_for_single_htlc,
1981
                                    amount_msat=part_amount_msat,
1982
                                    invoice_pubkey=paysession.invoice_pubkey,
1983
                                    min_final_cltv_delta=paysession.min_final_cltv_delta,
1984
                                    r_tags=paysession.r_tags,
1985
                                    invoice_features=paysession.invoice_features,
1986
                                    my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
1987
                                    full_path=full_path,
1988
                                    budget=budget._replace(fee_msat=budget.fee_msat // sc.config.number_parts()),
1989
                                )
1990
                            )
1991
                            shi = SentHtlcInfo(
5✔
1992
                                route=route,
1993
                                payment_secret_orig=paysession.payment_secret,
1994
                                payment_secret_bucket=paysession.payment_secret,
1995
                                amount_msat=part_amount_msat,
1996
                                bucket_msat=paysession.amount_to_pay,
1997
                                amount_receiver_msat=part_amount_msat,
1998
                                trampoline_fee_level=None,
1999
                                trampoline_route=None,
2000
                            )
2001
                            routes.append((shi, paysession.min_final_cltv_delta, fwd_trampoline_onion))
5✔
2002
            except NoPathFound:
5✔
2003
                continue
5✔
2004
            except FeeBudgetExceeded as e:
5✔
2005
                fee_related_error = e
×
2006
                continue
×
2007
            for route in routes:
5✔
2008
                yield route
5✔
2009
            return
5✔
2010
        if fee_related_error is not None:
5✔
2011
            raise fee_related_error
×
2012
        raise NoPathFound()
5✔
2013

2014
    @profiler
5✔
2015
    def create_route_for_single_htlc(
5✔
2016
            self, *,
2017
            amount_msat: int,  # that final receiver gets
2018
            invoice_pubkey: bytes,
2019
            min_final_cltv_delta: int,
2020
            r_tags,
2021
            invoice_features: int,
2022
            my_sending_channels: List[Channel],
2023
            full_path: Optional[LNPaymentPath],
2024
            budget: PaymentFeeBudget,
2025
    ) -> LNPaymentRoute:
2026

2027
        my_sending_aliases = set(chan.get_local_scid_alias() for chan in my_sending_channels)
5✔
2028
        my_sending_channels = {chan.short_channel_id: chan for chan in my_sending_channels
5✔
2029
            if chan.short_channel_id is not None}
2030
        # Collect all private edges from route hints.
2031
        # Note: if some route hints are multiple edges long, and these paths cross each other,
2032
        #       we allow our path finding to cross the paths; i.e. the route hints are not isolated.
2033
        private_route_edges = {}  # type: Dict[ShortChannelID, RouteEdge]
5✔
2034
        for private_path in r_tags:
5✔
2035
            # we need to shift the node pubkey by one towards the destination:
2036
            private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
5✔
2037
            private_path_rest = [edge[1:] for edge in private_path]
5✔
2038
            start_node = private_path[0][0]
5✔
2039
            # remove aliases from direct routes
2040
            if len(private_path) == 1 and private_path[0][1] in my_sending_aliases:
5✔
2041
                self.logger.info(f'create_route: skipping alias {ShortChannelID(private_path[0][1])}')
×
2042
                continue
×
2043
            for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
5✔
2044
                short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_delta = edge_rest
5✔
2045
                short_channel_id = ShortChannelID(short_channel_id)
5✔
2046
                # if we have a routing policy for this edge in the db, that takes precedence,
2047
                # as it is likely from a previous failure
2048
                channel_policy = self.channel_db.get_policy_for_node(
5✔
2049
                    short_channel_id=short_channel_id,
2050
                    node_id=start_node,
2051
                    my_channels=my_sending_channels)
2052
                if channel_policy:
5✔
2053
                    fee_base_msat = channel_policy.fee_base_msat
5✔
2054
                    fee_proportional_millionths = channel_policy.fee_proportional_millionths
5✔
2055
                    cltv_delta = channel_policy.cltv_delta
5✔
2056
                node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
5✔
2057
                route_edge = RouteEdge(
5✔
2058
                        start_node=start_node,
2059
                        end_node=end_node,
2060
                        short_channel_id=short_channel_id,
2061
                        fee_base_msat=fee_base_msat,
2062
                        fee_proportional_millionths=fee_proportional_millionths,
2063
                        cltv_delta=cltv_delta,
2064
                        node_features=node_info.features if node_info else 0)
2065
                private_route_edges[route_edge.short_channel_id] = route_edge
5✔
2066
                start_node = end_node
5✔
2067
        # now find a route, end to end: between us and the recipient
2068
        try:
5✔
2069
            route = self.network.path_finder.find_route(
5✔
2070
                nodeA=self.node_keypair.pubkey,
2071
                nodeB=invoice_pubkey,
2072
                invoice_amount_msat=amount_msat,
2073
                path=full_path,
2074
                my_sending_channels=my_sending_channels,
2075
                private_route_edges=private_route_edges)
2076
        except NoChannelPolicy as e:
5✔
2077
            raise NoPathFound() from e
×
2078
        if not route:
5✔
2079
            raise NoPathFound()
5✔
2080
        if not is_route_within_budget(
5✔
2081
            route, budget=budget, amount_msat_for_dest=amount_msat, cltv_delta_for_dest=min_final_cltv_delta,
2082
        ):
2083
            self.logger.info(f"rejecting route (exceeds budget): {route=}. {budget=}")
×
2084
            raise FeeBudgetExceeded()
×
2085
        assert len(route) > 0
5✔
2086
        if route[-1].end_node != invoice_pubkey:
5✔
2087
            raise LNPathInconsistent("last node_id != invoice pubkey")
5✔
2088
        # add features from invoice
2089
        route[-1].node_features |= invoice_features
5✔
2090
        return route
5✔
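    # Illustrative sketch (not part of lnworker.py): a bolt11 'r' hint lists
    # (start_node, scid, fee_base, fee_prop, cltv) per hop, so the end node of each
    # hinted edge is the start node of the next hint, and the last edge ends at the
    # invoice pubkey -- which is what the pubkey shift above computes:
    def _example_hint_endpoints(private_path: list, invoice_pubkey: bytes) -> list:
        starts = [edge[0] for edge in private_path]
        ends = starts[1:] + [invoice_pubkey]
        return list(zip(starts, ends))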
2091

2092
    def clear_invoices_cache(self):
5✔
2093
        self._bolt11_cache.clear()
×
2094

2095
    def get_bolt11_invoice(
5✔
2096
            self, *,
2097
            payment_hash: bytes,
2098
            amount_msat: Optional[int],
2099
            message: str,
2100
            expiry: int,  # expiration of invoice (in seconds, relative)
2101
            fallback_address: Optional[str],
2102
            channels: Optional[Sequence[Channel]] = None,
2103
            min_final_cltv_expiry_delta: Optional[int] = None,
2104
    ) -> Tuple[LnAddr, str]:
2105
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
×
2106

2107
        pair = self._bolt11_cache.get(payment_hash)
×
2108
        if pair:
×
2109
            lnaddr, invoice = pair
×
2110
            assert lnaddr.get_amount_msat() == amount_msat
×
2111
            return pair
×
2112

2113
        assert amount_msat is None or amount_msat > 0
×
2114
        timestamp = int(time.time())
×
NEW
2115
        needs_jit: bool = self._receive_requires_jit_channel(amount_msat)
×
NEW
2116
        routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit)
×
NEW
2117
        self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}")
×
2118
        invoice_features = self.features.for_invoice()
×
2119
        if not self.uses_trampoline():
×
2120
            invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
×
NEW
2121
        if needs_jit:
×
2122
            # JIT only works with a single HTLC; MPP would cause the LSP to open a channel for each HTLC
NEW
2123
            invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
×
2124
        payment_secret = self.get_payment_secret(payment_hash)
×
2125
        amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
×
2126
        if expiry == 0:
×
2127
            expiry = LN_EXPIRY_NEVER
×
2128
        if min_final_cltv_expiry_delta is None:
×
2129
            min_final_cltv_expiry_delta = MIN_FINAL_CLTV_DELTA_FOR_INVOICE
×
2130
        lnaddr = LnAddr(
×
2131
            paymenthash=payment_hash,
2132
            amount=amount_btc,
2133
            tags=[
2134
                ('d', message),
2135
                ('c', min_final_cltv_expiry_delta),
2136
                ('x', expiry),
2137
                ('9', invoice_features),
2138
                ('f', fallback_address),
2139
            ] + routing_hints,
2140
            date=timestamp,
2141
            payment_secret=payment_secret)
2142
        invoice = lnencode(lnaddr, self.node_keypair.privkey)
×
2143
        pair = lnaddr, invoice
×
2144
        self._bolt11_cache[payment_hash] = pair
×
2145
        return pair
×
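# Illustrative usage sketch, not part of lnworker.py: how a caller might obtain a
# BOLT-11 invoice from the method above. `lnwallet` is assumed to be an initialized
# LNWallet instance with at least one usable channel.
def _example_request_invoice(lnwallet) -> str:
    payment_hash = lnwallet.create_payment_info(amount_msat=100_000_000)  # 100,000 sat
    _lnaddr, bolt11 = lnwallet.get_bolt11_invoice(
        payment_hash=payment_hash,
        amount_msat=100_000_000,
        message="example invoice",
        expiry=3600,                # seconds, relative
        fallback_address=None,
    )
    return bolt11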
2146

2147
    def get_payment_secret(self, payment_hash):
5✔
2148
        return sha256(sha256(self.payment_secret_key) + payment_hash)
5✔
2149

2150
    def _get_payment_key(self, payment_hash: bytes) -> bytes:
5✔
2151
        """Return payment bucket key.
2152
        We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
2153
        as it changes over a trampoline path (in the outer onion), and these paths can overlap.
2154
        """
2155
        payment_secret = self.get_payment_secret(payment_hash)
5✔
2156
        return payment_hash + payment_secret
5✔
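# Standalone sketch, not part of lnworker.py: mirrors the key derivation above using
# hashlib only. The payment_secret_key value in the usage note is made up.
import hashlib

def _example_payment_key(payment_secret_key: bytes, payment_hash: bytes) -> bytes:
    def sha(data: bytes) -> bytes:
        return hashlib.sha256(data).digest()
    payment_secret = sha(sha(payment_secret_key) + payment_hash)   # as in get_payment_secret
    return payment_hash + payment_secret                           # as in _get_payment_key

# e.g. _example_payment_key(b'\x01' * 32, hashlib.sha256(b'preimage').digest()) -> 64-byte bucket key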
2157

2158
    def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
5✔
2159
        payment_preimage = os.urandom(32)
5✔
2160
        payment_hash = sha256(payment_preimage)
5✔
2161
        info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
5✔
2162
        self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
5✔
2163
        self.save_payment_info(info, write_to_disk=False)
5✔
2164
        if write_to_disk:
5✔
2165
            self.wallet.save_db()
×
2166
        return payment_hash
5✔
2167

2168
    def bundle_payments(self, hash_list):
5✔
2169
        payment_keys = [self._get_payment_key(x) for x in hash_list]
5✔
2170
        self.payment_bundles.append(payment_keys)
5✔
2171

2172
    def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
5✔
2173
        for key_list in self.payment_bundles:
5✔
2174
            if payment_key in key_list:
5✔
2175
                return key_list
5✔
2176

2177
    def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
5✔
2178
        if sha256(preimage) != payment_hash:
5✔
2179
            raise Exception("tried to save incorrect preimage for payment_hash")
×
2180
        self.preimages[payment_hash.hex()] = preimage.hex()
5✔
2181
        if write_to_disk:
5✔
2182
            self.wallet.save_db()
5✔
2183

2184
    def get_preimage(self, payment_hash: bytes) -> Optional[bytes]:
5✔
2185
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
5✔
2186
        preimage_hex = self.preimages.get(payment_hash.hex())
5✔
2187
        if preimage_hex is None:
5✔
2188
            return None
5✔
2189
        preimage_bytes = bytes.fromhex(preimage_hex)
5✔
2190
        if sha256(preimage_bytes) != payment_hash:
5✔
2191
            raise Exception("found incorrect preimage for payment_hash")
×
2192
        return preimage_bytes
5✔
2193

2194
    def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
5✔
2195
        """returns None if payment_hash is a payment we are forwarding"""
2196
        key = payment_hash.hex()
5✔
2197
        with self.lock:
5✔
2198
            if key in self.payment_info:
5✔
2199
                amount_msat, direction, status = self.payment_info[key]
5✔
2200
                return PaymentInfo(payment_hash, amount_msat, direction, status)
5✔
2201

2202
    def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: int):
5✔
2203
        info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
×
2204
        self.save_payment_info(info, write_to_disk=False)
×
2205

2206
    def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
5✔
2207
        self.hold_invoice_callbacks[payment_hash] = cb
5✔
2208

2209
    def unregister_hold_invoice(self, payment_hash: bytes):
5✔
2210
        self.hold_invoice_callbacks.pop(payment_hash)
×
2211

2212
    def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
5✔
2213
        key = info.payment_hash.hex()
5✔
2214
        assert info.status in SAVED_PR_STATUS
5✔
2215
        with self.lock:
5✔
2216
            self.payment_info[key] = info.amount_msat, info.direction, info.status
5✔
2217
        if write_to_disk:
5✔
2218
            self.wallet.save_db()
5✔
2219

2220
    def check_mpp_status(
5✔
2221
            self, *,
2222
            payment_secret: bytes,
2223
            short_channel_id: ShortChannelID,
2224
            htlc: UpdateAddHtlc,
2225
            expected_msat: int,
2226
    ) -> RecvMPPResolution:
2227
        """Returns the status of the incoming htlc set the given *htlc* belongs to.
2228

2229
        ACCEPTED simply means the mpp set is complete, and we can proceed with further
2230
        checks before fulfilling (or failing) the htlcs.
2231
        In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state
2232
        for quite some time -- not in the "WAITING" state (which would refer to the mpp set
2233
        not yet being complete!).
2234
        """
2235
        payment_hash = htlc.payment_hash
5✔
2236
        payment_key = payment_hash + payment_secret
5✔
2237
        self.update_mpp_with_received_htlc(
5✔
2238
            payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
2239
        mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
5✔
2240
        # if still waiting, calc resolution now:
2241
        if mpp_resolution == RecvMPPResolution.WAITING:
5✔
2242
            bundle = self.get_payment_bundle(payment_key)
5✔
2243
            if bundle:
5✔
2244
                payment_keys = bundle
5✔
2245
            else:
2246
                payment_keys = [payment_key]
5✔
2247
            first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
5✔
2248
            if self.get_payment_status(payment_hash) == PR_PAID:
5✔
2249
                mpp_resolution = RecvMPPResolution.ACCEPTED
×
2250
            elif self.stopping_soon:
5✔
2251
                # try to time out pending HTLCs before shutting down
2252
                mpp_resolution = RecvMPPResolution.EXPIRED
5✔
2253
            elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
5✔
2254
                mpp_resolution = RecvMPPResolution.ACCEPTED
5✔
2255
            elif time.time() - first_timestamp > self.MPP_EXPIRY:
5✔
2256
                mpp_resolution = RecvMPPResolution.EXPIRED
5✔
2257
            # save resolution, if any.
2258
            if mpp_resolution != RecvMPPResolution.WAITING:
5✔
2259
                for pkey in payment_keys:
5✔
2260
                    if pkey.hex() in self.received_mpp_htlcs:
5✔
2261
                        self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
5✔
2262

2263
        return mpp_resolution
5✔
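# Simplified standalone sketch, not part of lnworker.py: the resolution ladder used
# above, reduced to plain values. It omits payment bundles and the PR_PAID shortcut;
# the 120s expiry is a made-up stand-in for self.MPP_EXPIRY.
import time
from enum import IntEnum

class _ExampleResolution(IntEnum):
    WAITING = 0
    EXPIRED = 1
    ACCEPTED = 2

def _example_mpp_resolution(total_received_msat: int, expected_msat: int,
                            first_htlc_timestamp: float, stopping_soon: bool,
                            mpp_expiry_sec: float = 120) -> _ExampleResolution:
    if stopping_soon:
        return _ExampleResolution.EXPIRED      # time out pending HTLCs before shutdown
    if total_received_msat >= expected_msat:
        return _ExampleResolution.ACCEPTED     # the HTLC set is complete
    if time.time() - first_htlc_timestamp > mpp_expiry_sec:
        return _ExampleResolution.EXPIRED      # sender took too long to complete the set
    return _ExampleResolution.WAITING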
2264

2265
    def update_mpp_with_received_htlc(
5✔
2266
        self,
2267
        *,
2268
        payment_key: bytes,
2269
        scid: ShortChannelID,
2270
        htlc: UpdateAddHtlc,
2271
        expected_msat: int,
2272
    ):
2273
        # add new htlc to set
2274
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2275
        if mpp_status is None:
5✔
2276
            mpp_status = ReceivedMPPStatus(
5✔
2277
                resolution=RecvMPPResolution.WAITING,
2278
                expected_msat=expected_msat,
2279
                htlc_set=set(),
2280
            )
2281
        if expected_msat != mpp_status.expected_msat:
5✔
2282
            self.logger.info(
5✔
2283
                f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
2284
            mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
5✔
2285
        key = (scid, htlc)
5✔
2286
        if key not in mpp_status.htlc_set:
5✔
2287
            mpp_status.htlc_set.add(key)  # side-effecting htlc_set
5✔
2288
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status
5✔
2289

2290
    def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
5✔
2291
        mpp_status = self.received_mpp_htlcs[payment_key.hex()]
5✔
2292
        self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
5✔
2293
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
5✔
2294

2295
    def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
5✔
2296
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2297
        if not mpp_status:
5✔
2298
            return False
5✔
2299
        total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
5✔
2300
        return total >= mpp_status.expected_msat
5✔
2301

2302
    def is_accepted_mpp(self, payment_hash: bytes) -> bool:
5✔
2303
        payment_key = self._get_payment_key(payment_hash)
×
2304
        status = self.received_mpp_htlcs.get(payment_key.hex())
×
2305
        return status and status.resolution == RecvMPPResolution.ACCEPTED
×
2306

2307
    def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
5✔
2308
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2309
        if not mpp_status:
5✔
2310
            return int(time.time())
5✔
2311
        return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
5✔
2312

2313
    def maybe_cleanup_mpp(
5✔
2314
            self,
2315
            short_channel_id: ShortChannelID,
2316
            htlc: UpdateAddHtlc,
2317
    ) -> None:
2318

2319
        htlc_key = (short_channel_id, htlc)
5✔
2320
        for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
5✔
2321
            if htlc_key not in mpp_status.htlc_set:
5✔
2322
                continue
5✔
2323
            assert mpp_status.resolution != RecvMPPResolution.WAITING
5✔
2324
            self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
5✔
2325
            mpp_status.htlc_set.remove(htlc_key)  # side-effecting htlc_set
5✔
2326
            if len(mpp_status.htlc_set) == 0:
5✔
2327
                self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
5✔
2328
                self.received_mpp_htlcs.pop(payment_key_hex)
5✔
2329
                self.maybe_cleanup_forwarding(payment_key_hex)
5✔
2330

2331
    def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
5✔
2332
        self.active_forwardings.pop(payment_key_hex, None)
5✔
2333
        self.forwarding_failures.pop(payment_key_hex, None)
5✔
2334

2335
    def get_payment_status(self, payment_hash: bytes) -> int:
5✔
2336
        info = self.get_payment_info(payment_hash)
5✔
2337
        return info.status if info else PR_UNPAID
5✔
2338

2339
    def get_invoice_status(self, invoice: BaseInvoice) -> int:
5✔
2340
        invoice_id = invoice.rhash
5✔
2341
        status = self.get_payment_status(bfh(invoice_id))
5✔
2342
        if status == PR_UNPAID and invoice_id in self.inflight_payments:
5✔
2343
            return PR_INFLIGHT
×
2344
        # status may be PR_FAILED
2345
        if status == PR_UNPAID and invoice_id in self.logs:
5✔
2346
            status = PR_FAILED
×
2347
        return status
5✔
2348

2349
    def set_invoice_status(self, key: str, status: int) -> None:
5✔
2350
        if status == PR_INFLIGHT:
5✔
2351
            self.inflight_payments.add(key)
5✔
2352
        elif key in self.inflight_payments:
5✔
2353
            self.inflight_payments.remove(key)
5✔
2354
        if status in SAVED_PR_STATUS:
5✔
2355
            self.set_payment_status(bfh(key), status)
5✔
2356
        util.trigger_callback('invoice_status', self.wallet, key, status)
5✔
2357
        self.logger.info(f"set_invoice_status {key}: {status}")
5✔
2358
        # liquidity changed
2359
        self.clear_invoices_cache()
5✔
2360

2361
    def set_request_status(self, payment_hash: bytes, status: int) -> None:
5✔
2362
        if self.get_payment_status(payment_hash) == status:
5✔
2363
            return
5✔
2364
        self.set_payment_status(payment_hash, status)
5✔
2365
        request_id = payment_hash.hex()
5✔
2366
        req = self.wallet.get_request(request_id)
5✔
2367
        if req is None:
5✔
2368
            return
5✔
2369
        util.trigger_callback('request_status', self.wallet, request_id, status)
5✔
2370

2371
    def set_payment_status(self, payment_hash: bytes, status: int) -> None:
5✔
2372
        info = self.get_payment_info(payment_hash)
5✔
2373
        if info is None:
5✔
2374
            # if we are forwarding
2375
            return
5✔
2376
        info = info._replace(status=status)
5✔
2377
        self.save_payment_info(info)
5✔
2378

2379
    def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
5✔
2380
        """Returns whether this was a forwarded HTLC."""
2381
        for payment_key, htlcs in self.active_forwardings.items():
5✔
2382
            if htlc_key in htlcs:
5✔
2383
                return payment_key
5✔
2384

2385
    def notify_upstream_peer(self, htlc_key: str) -> None:
5✔
2386
        """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
2387
        If we find this was a forwarded HTLC, the upstream peer is notified.
2388
        """
2389
        upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None)
5✔
2390
        if not upstream_key:
5✔
2391
            return
4✔
2392
        upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
5✔
2393
        upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
5✔
2394
        upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
5✔
2395
        if upstream_peer:
5✔
2396
            upstream_peer.downstream_htlc_resolved_event.set()
5✔
2397
            upstream_peer.downstream_htlc_resolved_event.clear()
5✔
2398

2399
    def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
5✔
2400

2401
        util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
5✔
2402
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
5✔
2403
        fw_key = self.is_forwarded_htlc(htlc_key)
5✔
2404
        if fw_key:
5✔
2405
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2406
            fw_htlcs.remove(htlc_key)
5✔
2407

2408
        if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
5✔
2409
            chan.pop_onion_key(htlc_id)
5✔
2410
            payment_key = payment_hash + shi.payment_secret_orig
5✔
2411
            paysession = self._paysessions[payment_key]
5✔
2412
            q = paysession.sent_htlcs_q
5✔
2413
            htlc_log = HtlcLog(
5✔
2414
                success=True,
2415
                route=shi.route,
2416
                amount_msat=shi.amount_receiver_msat,
2417
                trampoline_fee_level=shi.trampoline_fee_level)
2418
            q.put_nowait(htlc_log)
5✔
2419
            if paysession.can_be_deleted():
5✔
2420
                self._paysessions.pop(payment_key)
5✔
2421
                paysession_active = False
5✔
2422
            else:
2423
                paysession_active = True
5✔
2424
        else:
2425
            if fw_key:
5✔
2426
                paysession_active = False
5✔
2427
            else:
2428
                key = payment_hash.hex()
5✔
2429
                self.set_invoice_status(key, PR_PAID)
5✔
2430
                util.trigger_callback('payment_succeeded', self.wallet, key)
5✔
2431

2432
        if fw_key:
5✔
2433
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2434
            if len(fw_htlcs) == 0 and not paysession_active:
5✔
2435
                self.notify_upstream_peer(htlc_key)
5✔
2436

2437
    def htlc_failed(
5✔
2438
            self,
2439
            chan: Channel,
2440
            payment_hash: bytes,
2441
            htlc_id: int,
2442
            error_bytes: Optional[bytes],
2443
            failure_message: Optional['OnionRoutingFailure']):
2444

2445
        util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
5✔
2446
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
5✔
2447
        fw_key = self.is_forwarded_htlc(htlc_key)
5✔
2448
        if fw_key:
5✔
2449
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2450
            fw_htlcs.remove(htlc_key)
5✔
2451

2452
        if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
5✔
2453
            onion_key = chan.pop_onion_key(htlc_id)
5✔
2454
            payment_okey = payment_hash + shi.payment_secret_orig
5✔
2455
            paysession = self._paysessions[payment_okey]
5✔
2456
            q = paysession.sent_htlcs_q
5✔
2457
            # detect if it is part of a bucket
2458
            # if yes, wait until the bucket completely failed
2459
            route = shi.route
5✔
2460
            if error_bytes:
5✔
2461
                # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
2462
                try:
5✔
2463
                    failure_message, sender_idx = decode_onion_error(
5✔
2464
                        error_bytes,
2465
                        [x.node_id for x in route],
2466
                        onion_key)
2467
                except Exception as e:
4✔
2468
                    sender_idx = None
4✔
2469
                    failure_message = OnionRoutingFailure(OnionFailureCode.INVALID_ONION_PAYLOAD, str(e).encode())
4✔
2470
            else:
2471
                # probably got "update_fail_malformed_htlc". well... who to penalise now?
2472
                assert failure_message is not None
×
2473
                sender_idx = None
×
2474
            self.logger.info(f"htlc_failed {failure_message}")
5✔
2475
            amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
5✔
2476
            if amount_receiver_msat is None:
5✔
2477
                return
5✔
2478
            if shi.trampoline_route:
5✔
2479
                route = shi.trampoline_route
5✔
2480
            htlc_log = HtlcLog(
5✔
2481
                success=False,
2482
                route=route,
2483
                amount_msat=amount_receiver_msat,
2484
                error_bytes=error_bytes,
2485
                failure_msg=failure_message,
2486
                sender_idx=sender_idx,
2487
                trampoline_fee_level=shi.trampoline_fee_level)
2488
            q.put_nowait(htlc_log)
5✔
2489
            if paysession.can_be_deleted():
5✔
2490
                self._paysessions.pop(payment_okey)
5✔
2491
                paysession_active = False
5✔
2492
            else:
2493
                paysession_active = True
5✔
2494
        else:
2495
            if fw_key:
5✔
2496
                paysession_active = False
5✔
2497
            else:
2498
                self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
5✔
2499
                key = payment_hash.hex()
5✔
2500
                self.set_invoice_status(key, PR_UNPAID)
5✔
2501
                util.trigger_callback('payment_failed', self.wallet, key, '')
5✔
2502

2503
        if fw_key:
5✔
2504
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2505
            can_forward_failure = (len(fw_htlcs) == 0) and not paysession_active
5✔
2506
            if can_forward_failure:
5✔
2507
                self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
5✔
2508
                self.notify_upstream_peer(htlc_key)
5✔
2509
            else:
2510
                self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
5✔
2511

2512
    def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False):
5✔
2513
        """calculate routing hints (BOLT-11 'r' field)"""
2514
        routing_hints = []
5✔
2515
        if needs_jit:
5✔
2516
            node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
×
NEW
2517
            alias_or_scid = self.get_static_jit_scid_alias()
×
2518
            routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
×
2519
            # no need for more hints: we cannot receive enough through the other channels, and MPP is disabled for JIT
2520
            channels = []
×
2521
        else:
2522
            if channels is None:
5✔
2523
                channels = list(self.get_channels_for_receiving(amount_msat))
5✔
2524
                random.shuffle(channels)  # let's not leak channel order
5✔
2525
            scid_to_my_channels = {
5✔
2526
                chan.short_channel_id: chan for chan in channels
2527
                if chan.short_channel_id is not None
2528
            }
2529
        for chan in channels:
5✔
2530
            alias_or_scid = chan.get_remote_scid_alias() or chan.short_channel_id
5✔
2531
            assert isinstance(alias_or_scid, bytes), alias_or_scid
5✔
2532
            channel_info = get_mychannel_info(chan.short_channel_id, scid_to_my_channels)
5✔
2533
            # note: as a fallback, if we don't have a channel update for the
2534
            # incoming direction of our private channel, we fill the invoice with garbage.
2535
            # the sender should still be able to pay us, but will incur an extra round trip
2536
            # (they will get the channel update from the onion error)
2537
            # at least, that's the theory. https://github.com/lightningnetwork/lnd/issues/2066
2538
            fee_base_msat = fee_proportional_millionths = 0
5✔
2539
            cltv_delta = 1  # lnd won't even try with zero
5✔
2540
            missing_info = True
5✔
2541
            if channel_info:
5✔
2542
                policy = get_mychannel_policy(channel_info.short_channel_id, chan.node_id, scid_to_my_channels)
5✔
2543
                if policy:
5✔
2544
                    fee_base_msat = policy.fee_base_msat
5✔
2545
                    fee_proportional_millionths = policy.fee_proportional_millionths
5✔
2546
                    cltv_delta = policy.cltv_delta
5✔
2547
                    missing_info = False
5✔
2548
            if missing_info:
5✔
2549
                self.logger.info(
×
2550
                    f"Warning. Missing channel update for our channel {chan.short_channel_id}; "
2551
                    f"filling invoice with incorrect data.")
2552
            routing_hints.append(('r', [(
5✔
2553
                chan.node_id,
2554
                alias_or_scid,
2555
                fee_base_msat,
2556
                fee_proportional_millionths,
2557
                cltv_delta)]))
2558
        return routing_hints
5✔
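# Illustrative sketch, not part of lnworker.py: the shape of a single BOLT-11 'r'
# (routing hint) tag as appended above. Node id and scid bytes are dummies.
def _example_routing_hint() -> tuple:
    node_id = bytes(33)        # compressed pubkey of our channel peer
    alias_or_scid = bytes(8)   # remote scid alias if available, else the real short channel id
    fee_base_msat = 0
    fee_proportional_millionths = 0
    cltv_expiry_delta = 144
    return ('r', [(node_id, alias_or_scid, fee_base_msat,
                   fee_proportional_millionths, cltv_expiry_delta)])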
2559

2560
    def delete_payment_info(self, payment_hash_hex: str):
5✔
2561
        # This method is called when an invoice or request is deleted by the user.
2562
        # The GUI only lets the user delete invoices or requests that have not been paid.
2563
        # Once an invoice/request has been paid, it is part of the history,
2564
        # and get_lightning_history assumes that payment_info is there.
2565
        assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
×
2566
        with self.lock:
×
2567
            self.payment_info.pop(payment_hash_hex, None)
×
2568

2569
    def get_balance(self, frozen=False):
5✔
2570
        with self.lock:
×
2571
            return Decimal(sum(
×
2572
                chan.balance(LOCAL) if not chan.is_closed() and (chan.is_frozen_for_sending() if frozen else True) else 0
2573
                for chan in self.channels.values())) / 1000
2574

2575
    def get_channels_for_sending(self):
5✔
2576
        for c in self.channels.values():
×
2577
            if c.is_active() and not c.is_frozen_for_sending():
×
2578
                if self.channel_db or self.is_trampoline_peer(c.node_id):
×
2579
                    yield c
×
2580

2581
    def fee_estimate(self, amount_sat):
5✔
2582
        # Here we have to guess a fee, because some callers (submarine swaps)
2583
        # use this method to initiate a payment, which would otherwise fail.
2584
        fee_base_msat = 5000               # FIXME ehh.. there ought to be a better way...
×
2585
        fee_proportional_millionths = 500  # FIXME
×
2586
        # inverse of fee_for_edge_msat
2587
        amount_msat = amount_sat * 1000
×
2588
        amount_minus_fees = (amount_msat - fee_base_msat) * 1_000_000 // ( 1_000_000 + fee_proportional_millionths)
×
2589
        return Decimal(amount_msat - amount_minus_fees) / 1000
×
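# Worked example, not part of lnworker.py: the guessed fee for sending 100,000 sat
# with the hard-coded 5,000 msat base fee and 500 ppm used above.
from decimal import Decimal

def _example_fee_estimate(amount_sat: int) -> Decimal:
    fee_base_msat = 5000
    fee_proportional_millionths = 500
    amount_msat = amount_sat * 1000
    amount_minus_fees = (amount_msat - fee_base_msat) * 1_000_000 // (1_000_000 + fee_proportional_millionths)
    return Decimal(amount_msat - amount_minus_fees) / 1000

# _example_fee_estimate(100_000) == Decimal('54.973')   # roughly 55 sat set aside for fees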
2590

2591
    def num_sats_can_send(self, deltas=None) -> Decimal:
5✔
2592
        """
2593
        Without trampoline routing, this is the sum of the send capacities of all channels.
2594
        With trampoline routing, an MPP payment must go through a single trampoline node, so we take the best single trampoline.
2595
        """
2596
        if deltas is None:
×
2597
            deltas = {}
×
2598

2599
        def send_capacity(chan):
×
2600
            if chan in deltas:
×
2601
                delta_msat = deltas[chan] * 1000
×
2602
                if delta_msat > chan.available_to_spend(REMOTE):
×
2603
                    delta_msat = 0
×
2604
            else:
2605
                delta_msat = 0
×
2606
            return chan.available_to_spend(LOCAL) + delta_msat
×
2607
        can_send_dict = defaultdict(int)
×
2608
        with self.lock:
×
2609
            for c in self.get_channels_for_sending():
×
2610
                if not self.uses_trampoline():
×
2611
                    can_send_dict[0] += send_capacity(c)
×
2612
                else:
2613
                    can_send_dict[c.node_id] += send_capacity(c)
×
2614
        can_send = max(can_send_dict.values()) if can_send_dict else 0
×
2615
        can_send_sat = Decimal(can_send)/1000
×
2616
        can_send_sat -= self.fee_estimate(can_send_sat)
×
2617
        return max(can_send_sat, 0)
×
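# Illustrative sketch, not part of lnworker.py: the aggregation above with made-up
# per-channel send capacities (msat), keyed by (trampoline) node id.
from collections import defaultdict

def _example_can_send_msat(capacities_by_node: dict, uses_trampoline: bool) -> int:
    buckets = defaultdict(int)
    for node_id, channel_capacities in capacities_by_node.items():
        for cap_msat in channel_capacities:
            buckets[node_id if uses_trampoline else 0] += cap_msat
    return max(buckets.values()) if buckets else 0

# _example_can_send_msat({'t1': [2_000_000, 1_000_000], 't2': [2_500_000]}, True)  -> 3_000_000
# _example_can_send_msat({'t1': [2_000_000, 1_000_000], 't2': [2_500_000]}, False) -> 5_500_000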
2618

2619
    def get_channels_for_receiving(self, amount_msat=None) -> Sequence[Channel]:
5✔
2620
        if not amount_msat:  # assume we want to recv a large amt, e.g. finding max.
5✔
2621
            amount_msat = float('inf')
×
2622
        with self.lock:
5✔
2623
            channels = list(self.channels.values())
5✔
2624
            # we exclude channels that cannot *right now* receive (e.g. peer offline)
2625
            channels = [chan for chan in channels
5✔
2626
                        if (chan.is_open() and not chan.is_frozen_for_receiving())]
2627
            # Filter out nodes that have low receive capacity compared to invoice amt.
2628
            # Even with MPP, below a certain threshold, including these channels probably
2629
            # hurts more than help, as they lead to many failed attempts for the sender.
2630
            channels = sorted(channels, key=lambda chan: -chan.available_to_spend(REMOTE))
5✔
2631
            selected_channels = []
5✔
2632
            running_sum = 0
5✔
2633
            cutoff_factor = 0.2  # heuristic
5✔
2634
            for chan in channels:
5✔
2635
                recv_capacity = chan.available_to_spend(REMOTE)
5✔
2636
                chan_can_handle_payment_as_single_part = recv_capacity >= amount_msat
5✔
2637
                chan_small_compared_to_running_sum = recv_capacity < cutoff_factor * running_sum
5✔
2638
                if not chan_can_handle_payment_as_single_part and chan_small_compared_to_running_sum:
5✔
2639
                    break
5✔
2640
                running_sum += recv_capacity
5✔
2641
                selected_channels.append(chan)
5✔
2642
            channels = selected_channels
5✔
2643
            del selected_channels
5✔
2644
            # cap max channels to include to keep QR code reasonably scannable
2645
            channels = channels[:10]
5✔
2646
            return channels
5✔
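# Illustrative sketch, not part of lnworker.py: the cutoff heuristic above, applied to
# made-up receive capacities in msat. A channel is dropped once it can neither carry
# the payment by itself nor add at least 20% of what was already selected.
def _example_select_receive_channels(capacities_msat: list, amount_msat: float) -> list:
    selected, running_sum = [], 0
    for cap in sorted(capacities_msat, reverse=True):
        if cap < amount_msat and cap < 0.2 * running_sum:
            break
        running_sum += cap
        selected.append(cap)
    return selected[:10]   # keep the QR code scannable

# _example_select_receive_channels([5_000_000, 1_500_000, 100_000], amount_msat=6_000_000)
#   -> [5_000_000, 1_500_000]   (100k is below 20% of the 6.5M already selected)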
2647

2648
    def num_sats_can_receive(self, deltas=None) -> Decimal:
5✔
2649
        """
2650
        We no longer assume that the sender will split an MPP payment across different channels,
2651
        because channel liquidities are hard for the sender to guess.
2652
        """
2653
        if deltas is None:
×
2654
            deltas = {}
×
2655

2656
        def recv_capacity(chan):
×
2657
            if chan in deltas:
×
2658
                delta_msat = deltas[chan] * 1000
×
2659
                if delta_msat > chan.available_to_spend(LOCAL):
×
2660
                    delta_msat = 0
×
2661
            else:
2662
                delta_msat = 0
×
2663
            return chan.available_to_spend(REMOTE) + delta_msat
×
2664
        with self.lock:
×
2665
            recv_channels = self.get_channels_for_receiving()
×
2666
            recv_chan_msats = [recv_capacity(chan) for chan in recv_channels]
×
2667
        if not recv_chan_msats:
×
2668
            return Decimal(0)
×
2669
        can_receive_msat = max(recv_chan_msats)
×
2670
        return Decimal(can_receive_msat) / 1000
×
2671

2672
    def _receive_requires_jit_channel(self, amount_msat: Optional[int]) -> bool:
5✔
2673
        """Returns true if we cannot receive the amount and have set up a trusted LSP node.
2674
        This cannot work reliably with zero-amount invoices, as we don't know in advance whether we will be able to receive the payment.
2675
        """
2676
        # a trusted zeroconf node is configured
NEW
2677
        if (self.config.ZEROCONF_TRUSTED_NODE
×
2678
                # the zeroconf node is a peer, it doesn't make sense to request a channel from an offline LSP
2679
                and extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)[0] in self.peers
2680
                # we cannot receive the amount specified
2681
                and ((amount_msat and self.num_sats_can_receive() < (amount_msat // 1000))
2682
                    # or we cannot receive anything, and it's a 0 amount invoice
2683
                    or (not amount_msat and self.num_sats_can_receive() < 1))):
NEW
2684
            return True
×
NEW
2685
        return False
×
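# Illustrative sketch, not part of lnworker.py: the JIT decision above, reduced to plain
# values. `lsp_configured_and_online` stands in for the ZEROCONF_TRUSTED_NODE checks.
from typing import Optional

def _example_needs_jit(lsp_configured_and_online: bool,
                       can_receive_sat: int,
                       amount_msat: Optional[int]) -> bool:
    if not lsp_configured_and_online:
        return False
    if amount_msat:
        return can_receive_sat < (amount_msat // 1000)   # cannot receive the requested amount
    return can_receive_sat < 1                           # zero-amount invoice: only if we cannot receive at all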
2686

2687
    def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
5✔
2688
        """
2689
        Suggest channels and, for each, an amount to send/receive over that channel, so that we will afterwards be able to receive/send amount_sat.
2690
        This is used when suggesting a swap or rebalance in order to receive a payment.
2691
        """
2692
        with self.lock:
×
2693
            func = self.num_sats_can_send if direction == SENT else self.num_sats_can_receive
×
2694
            suggestions = []
×
2695
            channels = self.get_channels_for_sending() if direction == SENT else self.get_channels_for_receiving()
×
2696
            for chan in channels:
×
2697
                available_sat = chan.available_to_spend(LOCAL if direction == SENT else REMOTE) // 1000
×
2698
                delta = amount_sat - available_sat
×
2699
                delta += self.fee_estimate(amount_sat)
×
2700
                # add safety margin
2701
                delta += delta // 100 + 1
×
2702
                if func(deltas={chan:delta}) >= amount_sat:
×
2703
                    suggestions.append((chan, delta))
×
2704
                elif direction==RECEIVED and func(deltas={chan:2*delta}) >= amount_sat:
×
2705
                    # the MPP receive heuristic has a 0.5 slope, so try double the delta
2706
                    suggestions.append((chan, 2*delta))
×
2707
        if not suggestions:
×
2708
            raise NotEnoughFunds
×
2709
        return suggestions
×
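# Worked example, not part of lnworker.py: how the suggested delta for one channel is
# built from the shortfall, a fee guess and a ~1% safety margin. Numbers are made up;
# the real method then re-checks num_sats_can_send/receive with the delta applied.
def _example_rebalance_delta_sat(amount_sat: int, available_sat: int, fee_estimate_sat: int) -> int:
    delta = amount_sat - available_sat    # shortfall on this channel
    delta += fee_estimate_sat             # leave room for routing fees
    delta += delta // 100 + 1             # safety margin
    return delta

# _example_rebalance_delta_sat(amount_sat=100_000, available_sat=40_000, fee_estimate_sat=55) -> 60_656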
2710

2711
    def _suggest_rebalance(self, direction, amount_sat):
5✔
2712
        """
2713
        Suggest a rebalance in order to be able to send or receive amount_sat.
2714
        Returns (from_channel, to_channel, amount to shuffle)
2715
        """
2716
        try:
×
2717
            suggestions = self._suggest_channels_for_rebalance(direction, amount_sat)
×
2718
        except NotEnoughFunds:
×
2719
            return False
×
2720
        for chan2, delta in suggestions:
×
2721
            # margin for fee caused by rebalancing
2722
            delta += self.fee_estimate(amount_sat)
×
2723
            # find other channel or trampoline that can send delta
2724
            for chan1 in self.channels.values():
×
2725
                if chan1.is_frozen_for_sending() or not chan1.is_active():
×
2726
                    continue
×
2727
                if chan1 == chan2:
×
2728
                    continue
×
2729
                if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2730
                    continue
×
2731
                if direction == SENT:
×
2732
                    if chan1.can_pay(delta*1000):
×
2733
                        return (chan1, chan2, delta)
×
2734
                else:
2735
                    if chan1.can_receive(delta*1000):
×
2736
                        return (chan2, chan1, delta)
×
2737
            else:
2738
                continue
×
2739
        else:
2740
            return False
×
2741

2742
    def num_sats_can_rebalance(self, chan1, chan2):
5✔
2743
        # TODO: we should be able to spend 'max', with variable fee
2744
        n1 = chan1.available_to_spend(LOCAL)
×
2745
        n1 -= self.fee_estimate(n1)
×
2746
        n2 = chan2.available_to_spend(REMOTE)
×
2747
        amount_sat = min(n1, n2) // 1000
×
2748
        return amount_sat
×
2749

2750
    def suggest_rebalance_to_send(self, amount_sat):
5✔
2751
        return self._suggest_rebalance(SENT, amount_sat)
×
2752

2753
    def suggest_rebalance_to_receive(self, amount_sat):
5✔
2754
        return self._suggest_rebalance(RECEIVED, amount_sat)
×
2755

2756
    def suggest_swap_to_send(self, amount_sat, coins):
5✔
2757
        # fixme: if swap_amount_sat is lower than the minimum swap amount, we need to propose a higher value
2758
        assert amount_sat > self.num_sats_can_send()
×
2759
        try:
×
2760
            suggestions = self._suggest_channels_for_rebalance(SENT, amount_sat)
×
2761
        except NotEnoughFunds:
×
2762
            return
×
2763
        for chan, swap_recv_amount in suggestions:
×
2764
            # check that we can send onchain
2765
            swap_server_mining_fee = 10000 # guessing, because we have not called get_pairs yet
×
2766
            swap_funding_sat = swap_recv_amount + swap_server_mining_fee
×
2767
            swap_output = PartialTxOutput.from_address_and_value(DummyAddress.SWAP, int(swap_funding_sat))
×
2768
            if not self.wallet.can_pay_onchain([swap_output], coins=coins):
×
2769
                continue
×
2770
            return chan, swap_recv_amount
×
2771

2772
    def suggest_swap_to_receive(self, amount_sat):
5✔
2773
        assert amount_sat > self.num_sats_can_receive()
×
2774
        try:
×
2775
            suggestions = self._suggest_channels_for_rebalance(RECEIVED, amount_sat)
×
2776
        except NotEnoughFunds:
×
2777
            return
×
2778
        for chan, swap_recv_amount in suggestions:
×
2779
            return chan, swap_recv_amount
×
2780

2781
    async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msat: int):
5✔
2782
        if chan1 == chan2:
×
2783
            raise Exception('Rebalance requires two different channels')
×
2784
        if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2785
            raise Exception('Rebalance requires channels from different trampolines')
×
2786
        payment_hash = self.create_payment_info(amount_msat=amount_msat)
×
2787
        lnaddr, invoice = self.get_bolt11_invoice(
×
2788
            payment_hash=payment_hash,
2789
            amount_msat=amount_msat,
2790
            message='rebalance',
2791
            expiry=3600,
2792
            fallback_address=None,
2793
            channels=[chan2],
2794
        )
2795
        return await self.pay_invoice(
×
2796
            invoice, channels=[chan1])
2797

2798
    def can_receive_invoice(self, invoice: BaseInvoice) -> bool:
5✔
2799
        assert invoice.is_lightning()
×
2800
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_receive()
×
2801

2802
    async def close_channel(self, chan_id):
5✔
2803
        chan = self._channels[chan_id]
×
2804
        peer = self._peers[chan.node_id]
×
2805
        return await peer.close_channel(chan_id)
×
2806

2807
    def _force_close_channel(self, chan_id: bytes) -> Transaction:
5✔
2808
        chan = self._channels[chan_id]
5✔
2809
        tx = chan.force_close_tx()
5✔
2810
        # We set the channel state to make sure we won't sign new commitment txs.
2811
        # We expect the caller to try to broadcast this tx, after which it is
2812
        # not safe to keep using the channel even if the broadcast errors (server could be lying).
2813
        # Until the tx is seen in the mempool, there will be automatic rebroadcasts.
2814
        chan.set_state(ChannelState.FORCE_CLOSING)
5✔
2815
        # Add local tx to wallet to also allow manual rebroadcasts.
2816
        try:
5✔
2817
            self.wallet.adb.add_transaction(tx)
5✔
2818
        except UnrelatedTransactionException:
×
2819
            pass  # this can happen if (almost) all of the balance goes to REMOTE
×
2820
        return tx
5✔
2821

2822
    async def force_close_channel(self, chan_id: bytes) -> str:
5✔
2823
        """Force-close the channel. Network-related exceptions are propagated to the caller.
2824
        (automatic rebroadcasts will be scheduled)
2825
        """
2826
        # note: as we are async, it can take a few event loop iterations between the caller
2827
        #       "calling us" and us getting to run, and we only set the channel state now:
2828
        tx = self._force_close_channel(chan_id)
5✔
2829
        await self.network.broadcast_transaction(tx)
5✔
2830
        return tx.txid()
5✔
2831

2832
    def schedule_force_closing(self, chan_id: bytes) -> 'asyncio.Task[bool]':
5✔
2833
        """Schedules a task to force-close the channel and returns it.
2834
        Network-related exceptions are suppressed.
2835
        (automatic rebroadcasts will be scheduled)
2836
        Note: this method is intentionally not async so that callers have a guarantee
2837
              that the channel state is set immediately.
2838
        """
2839
        tx = self._force_close_channel(chan_id)
5✔
2840
        return asyncio.create_task(self.network.try_broadcasting(tx, 'force-close'))
5✔
2841

2842
    def remove_channel(self, chan_id):
5✔
2843
        chan = self.channels[chan_id]
×
2844
        assert chan.can_be_deleted()
×
2845
        with self.lock:
×
2846
            self._channels.pop(chan_id)
×
2847
            self.db.get('channels').pop(chan_id.hex())
×
2848
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
×
2849

2850
        util.trigger_callback('channels_updated', self.wallet)
×
2851
        util.trigger_callback('wallet_updated', self.wallet)
×
2852

2853
    @ignore_exceptions
5✔
2854
    @log_exceptions
5✔
2855
    async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
5✔
2856
        now = time.time()
×
2857
        peer_addresses = []
×
2858
        if self.uses_trampoline():
×
2859
            addr = trampolines_by_id().get(chan.node_id)
×
2860
            if addr:
×
2861
                peer_addresses.append(addr)
×
2862
        else:
2863
            # will try last good address first, from gossip
2864
            last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
×
2865
            if last_good_addr:
×
2866
                peer_addresses.append(last_good_addr)
×
2867
            # will try addresses for node_id from gossip
2868
            addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
×
2869
            for host, port, ts in addrs_from_gossip:
×
2870
                peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
×
2871
        # will try addresses stored in channel storage
2872
        peer_addresses += list(chan.get_peer_addresses())
×
2873
        # Done gathering addresses.
2874
        # Now select first one that has not failed recently.
2875
        for peer in peer_addresses:
×
2876
            if self._can_retry_addr(peer, urgent=True, now=now):
×
2877
                await self._add_peer(peer.host, peer.port, peer.pubkey)
×
2878
                return
×
2879

2880
    async def reestablish_peers_and_channels(self):
5✔
2881
        while True:
×
2882
            await asyncio.sleep(1)
×
2883
            if self.stopping_soon:
×
2884
                return
×
2885
            if self.config.ZEROCONF_TRUSTED_NODE:
×
2886
                peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
×
2887
                if self._can_retry_addr(peer, urgent=True):
×
2888
                    await self._add_peer(peer.host, peer.port, peer.pubkey)
×
2889
            for chan in self.channels.values():
×
2890
                # reestablish
2891
                # note: we delegate filtering out uninteresting chans to this:
2892
                if not chan.should_try_to_reestablish_peer():
×
2893
                    continue
×
2894
                peer = self._peers.get(chan.node_id, None)
×
2895
                if peer:
×
2896
                    await peer.taskgroup.spawn(peer.reestablish_channel(chan))
×
2897
                else:
2898
                    await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
×
2899

2900
    def current_target_feerate_per_kw(self) -> int:
5✔
2901
        from .simple_config import FEE_LN_ETA_TARGET, FEERATE_FALLBACK_STATIC_FEE
5✔
2902
        from .simple_config import FEERATE_PER_KW_MIN_RELAY_LIGHTNING
5✔
2903
        if constants.net is constants.BitcoinRegtest:
5✔
2904
            feerate_per_kvbyte = self.network.config.FEE_EST_STATIC_FEERATE
×
2905
        else:
2906
            feerate_per_kvbyte = self.network.config.eta_target_to_fee(FEE_LN_ETA_TARGET)
5✔
2907
            if feerate_per_kvbyte is None:
5✔
2908
                feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE
5✔
2909
        return max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
5✔
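# Worked example, not part of lnworker.py: a vbyte is 4 weight units, so a sat/kvB
# estimate divided by 4 gives sat/kw, clamped to the minimum relay feerate. The 253
# sat/kw floor is an assumption of this sketch; see FEERATE_PER_KW_MIN_RELAY_LIGHTNING.
def _example_feerate_per_kw(feerate_per_kvbyte: int, min_relay_per_kw: int = 253) -> int:
    return max(min_relay_per_kw, feerate_per_kvbyte // 4)

# _example_feerate_per_kw(20_000) -> 5_000   sat/kw
# _example_feerate_per_kw(800)    -> 253     sat/kw (clamped to the relay minimum)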
2910

2911
    def current_low_feerate_per_kw(self) -> int:
5✔
2912
        from .simple_config import FEE_LN_LOW_ETA_TARGET
5✔
2913
        from .simple_config import FEERATE_PER_KW_MIN_RELAY_LIGHTNING
5✔
2914
        if constants.net is constants.BitcoinRegtest:
5✔
2915
            feerate_per_kvbyte = 0
×
2916
        else:
2917
            feerate_per_kvbyte = self.network.config.eta_target_to_fee(FEE_LN_LOW_ETA_TARGET) or 0
5✔
2918
        low_feerate_per_kw = max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
5✔
2919
        # make sure this is never higher than the target feerate:
2920
        low_feerate_per_kw = min(low_feerate_per_kw, self.current_target_feerate_per_kw())
5✔
2921
        return low_feerate_per_kw
5✔
2922

2923
    def create_channel_backup(self, channel_id: bytes):
5✔
2924
        chan = self._channels[channel_id]
×
2925
        # do not backup old-style channels
2926
        assert chan.is_static_remotekey_enabled()
×
2927
        peer_addresses = list(chan.get_peer_addresses())
×
2928
        peer_addr = peer_addresses[0]
×
2929
        return ImportedChannelBackupStorage(
×
2930
            node_id = chan.node_id,
2931
            privkey = self.node_keypair.privkey,
2932
            funding_txid = chan.funding_outpoint.txid,
2933
            funding_index = chan.funding_outpoint.output_index,
2934
            funding_address = chan.get_funding_address(),
2935
            host = peer_addr.host,
2936
            port = peer_addr.port,
2937
            is_initiator = chan.constraints.is_initiator,
2938
            channel_seed = chan.config[LOCAL].channel_seed,
2939
            local_delay = chan.config[LOCAL].to_self_delay,
2940
            remote_delay = chan.config[REMOTE].to_self_delay,
2941
            remote_revocation_pubkey = chan.config[REMOTE].revocation_basepoint.pubkey,
2942
            remote_payment_pubkey = chan.config[REMOTE].payment_basepoint.pubkey,
2943
            local_payment_pubkey=chan.config[LOCAL].payment_basepoint.pubkey,
2944
            multisig_funding_privkey=chan.config[LOCAL].multisig_key.privkey,
2945
        )
2946

2947
    def export_channel_backup(self, channel_id):
5✔
2948
        xpub = self.wallet.get_fingerprint()
×
2949
        backup_bytes = self.create_channel_backup(channel_id).to_bytes()
×
2950
        assert backup_bytes == ImportedChannelBackupStorage.from_bytes(backup_bytes).to_bytes(), "roundtrip failed"
×
2951
        encrypted = pw_encode_with_version_and_mac(backup_bytes, xpub)
×
2952
        assert backup_bytes == pw_decode_with_version_and_mac(encrypted, xpub), "encrypt failed"
×
2953
        return 'channel_backup:' + encrypted
×
2954

2955
    async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
5✔
2956
        if channel_id in self.channels:
×
2957
            chan = self.channels[channel_id]
×
2958
            peer = self._peers.get(chan.node_id)
×
2959
            chan.should_request_force_close = True
×
2960
            if peer:
×
2961
                peer.close_and_cleanup()  # to force a reconnect
×
2962
        elif connect_str:
×
2963
            peer = await self.add_peer(connect_str)
×
2964
            await peer.request_force_close(channel_id)
×
2965
        elif channel_id in self.channel_backups:
×
2966
            await self._request_force_close_from_backup(channel_id)
×
2967
        else:
2968
            raise Exception(f'Unknown channel {channel_id.hex()}')
×
2969

2970
    def import_channel_backup(self, data):
5✔
2971
        xpub = self.wallet.get_fingerprint()
×
2972
        cb_storage = ImportedChannelBackupStorage.from_encrypted_str(data, password=xpub)
×
2973
        channel_id = cb_storage.channel_id()
×
2974
        if channel_id.hex() in self.db.get_dict("channels"):
×
2975
            raise Exception('Channel already in wallet')
×
2976
        self.logger.info(f'importing channel backup: {channel_id.hex()}')
×
2977
        d = self.db.get_dict("imported_channel_backups")
×
2978
        d[channel_id.hex()] = cb_storage
×
2979
        with self.lock:
×
2980
            cb = ChannelBackup(cb_storage, lnworker=self)
×
2981
            self._channel_backups[channel_id] = cb
×
2982
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
2983
        self.wallet.save_db()
×
2984
        util.trigger_callback('channels_updated', self.wallet)
×
2985
        self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
2986

2987
    def has_conflicting_backup_with(self, remote_node_id: bytes):
5✔
2988
        """ Returns whether we have an active channel with this node on another device, using same local node id. """
2989
        channel_backup_peers = [
×
2990
            cb.node_id for cb in self.channel_backups.values()
2991
            if (not cb.is_closed() and cb.get_local_pubkey() == self.node_keypair.pubkey)]
2992
        return any(remote_node_id.startswith(cb_peer_nodeid) for cb_peer_nodeid in channel_backup_peers)
×
2993

2994
    def remove_channel_backup(self, channel_id):
5✔
2995
        chan = self.channel_backups[channel_id]
×
2996
        assert chan.can_be_deleted()
×
2997
        found = False
×
2998
        onchain_backups = self.db.get_dict("onchain_channel_backups")
×
2999
        imported_backups = self.db.get_dict("imported_channel_backups")
×
3000
        if channel_id.hex() in onchain_backups:
×
3001
            onchain_backups.pop(channel_id.hex())
×
3002
            found = True
×
3003
        if channel_id.hex() in imported_backups:
×
3004
            imported_backups.pop(channel_id.hex())
×
3005
            found = True
×
3006
        if not found:
×
3007
            raise Exception('Channel not found')
×
3008
        with self.lock:
×
3009
            self._channel_backups.pop(channel_id)
×
3010
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
×
3011
        self.wallet.save_db()
×
3012
        util.trigger_callback('channels_updated', self.wallet)
×
3013

3014
    @log_exceptions
5✔
3015
    async def _request_force_close_from_backup(self, channel_id: bytes):
5✔
3016
        cb = self.channel_backups.get(channel_id)
×
3017
        if not cb:
×
3018
            raise Exception(f'channel backup not found {self.channel_backups}')
×
3019
        cb = cb.cb # storage
×
3020
        self.logger.info(f'requesting channel force close: {channel_id.hex()}')
×
3021
        if isinstance(cb, ImportedChannelBackupStorage):
×
3022
            node_id = cb.node_id
×
3023
            privkey = cb.privkey
×
3024
            addresses = [(cb.host, cb.port, 0)]
×
3025
        else:
3026
            assert isinstance(cb, OnchainChannelBackupStorage)
×
3027
            privkey = self.node_keypair.privkey
×
3028
            for pubkey, peer_addr in trampolines_by_id().items():
×
3029
                if pubkey.startswith(cb.node_id_prefix):
×
3030
                    node_id = pubkey
×
3031
                    addresses = [(peer_addr.host, peer_addr.port, 0)]
×
3032
                    break
×
3033
            else:
3034
                # we will try with gossip (see below)
3035
                addresses = []
×
3036

3037
        async def _request_fclose(addresses):
×
3038
            for host, port, timestamp in addresses:
×
3039
                peer_addr = LNPeerAddr(host, port, node_id)
×
3040
                transport = LNTransport(privkey, peer_addr, e_proxy=ESocksProxy.from_network_settings(self.network))
×
3041
                peer = Peer(self, node_id, transport, is_channel_backup=True)
×
3042
                try:
×
3043
                    async with OldTaskGroup(wait=any) as group:
×
3044
                        await group.spawn(peer._message_loop())
×
3045
                        await group.spawn(peer.request_force_close(channel_id))
×
3046
                    return True
×
3047
                except Exception as e:
×
3048
                    self.logger.info(f'failed to connect {host} {e}')
×
3049
                    continue
×
3050
            else:
3051
                return False
×
3052
        # try first without gossip db
3053
        success = await _request_fclose(addresses)
×
3054
        if success:
×
3055
            return
×
3056
        # try with gossip db
3057
        if self.uses_trampoline():
×
3058
            raise Exception(_('Please enable gossip'))
×
3059
        node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
×
3060
        addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)
×
3061
        if not addresses_from_gossip:
×
3062
            raise Exception('Peer not found in gossip database')
×
3063
        success = await _request_fclose(addresses_from_gossip)
×
3064
        if not success:
×
3065
            raise Exception('failed to connect')
×
3066

3067
    def maybe_add_backup_from_tx(self, tx):
5✔
3068
        funding_address = None
5✔
3069
        node_id_prefix = None
5✔
3070
        for i, o in enumerate(tx.outputs()):
5✔
3071
            script_type = get_script_type_from_output_script(o.scriptpubkey)
5✔
3072
            if script_type == 'p2wsh':
5✔
3073
                funding_index = i
×
3074
                funding_address = o.address
×
3075
                for o2 in tx.outputs():
×
3076
                    if o2.scriptpubkey.startswith(bytes([opcodes.OP_RETURN])):
×
3077
                        encrypted_data = o2.scriptpubkey[2:]
×
3078
                        data = self.decrypt_cb_data(encrypted_data, funding_address)
×
3079
                        if data.startswith(CB_MAGIC_BYTES):
×
3080
                            node_id_prefix = data[len(CB_MAGIC_BYTES):]
×
3081
        if node_id_prefix is None:
5✔
3082
            return
5✔
3083
        funding_txid = tx.txid()
×
3084
        cb_storage = OnchainChannelBackupStorage(
×
3085
            node_id_prefix = node_id_prefix,
3086
            funding_txid = funding_txid,
3087
            funding_index = funding_index,
3088
            funding_address = funding_address,
3089
            is_initiator = True)
3090
        channel_id = cb_storage.channel_id().hex()
×
3091
        if channel_id in self.db.get_dict("channels"):
×
3092
            return
×
3093
        self.logger.info(f"adding backup from tx")
×
3094
        d = self.db.get_dict("onchain_channel_backups")
×
3095
        d[channel_id] = cb_storage
×
3096
        cb = ChannelBackup(cb_storage, lnworker=self)
×
3097
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
3098
        self.wallet.save_db()
×
3099
        with self.lock:
×
3100
            self._channel_backups[bfh(channel_id)] = cb
×
3101
        util.trigger_callback('channels_updated', self.wallet)
×
3102
        self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
3103

3104
    def save_forwarding_failure(
5✔
3105
            self, payment_key:str, *,
3106
            error_bytes: Optional[bytes] = None,
3107
            failure_message: Optional['OnionRoutingFailure'] = None):
3108
        error_hex = error_bytes.hex() if error_bytes else None
5✔
3109
        failure_hex = failure_message.to_bytes().hex() if failure_message else None
5✔
3110
        self.forwarding_failures[payment_key] = (error_hex, failure_hex)
5✔
3111

3112
    def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
5✔
3113
        error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
5✔
3114
        error_bytes = bytes.fromhex(error_hex) if error_hex else None
5✔
3115
        failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
5✔
3116
        return error_bytes, failure_message
5✔
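# Illustrative sketch, not part of lnworker.py: the failure cache stores hex strings,
# so a saved (error_bytes, failure_message) entry round-trips back to bytes as above.
# The onion error bytes here are made up.
def _example_forwarding_failure_roundtrip() -> bytes:
    failures = {}                                                # payment_key -> (error_hex, failure_hex)
    error_bytes = bytes.fromhex('deadbeef')
    failures['example_payment_key'] = (error_bytes.hex(), None)  # cf. save_forwarding_failure
    error_hex, _failure_hex = failures['example_payment_key']    # cf. get_forwarding_failure
    return bytes.fromhex(error_hex)

# _example_forwarding_failure_roundtrip() == b'\xde\xad\xbe\xef'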