• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

spesmilo / electrum / 6330511049621504

25 Nov 2025 12:26PM UTC coverage: 62.326% (+0.1%) from 62.228%
6330511049621504

Pull #10230

CirrusCI

f321x
tests: lnpeer: test_dont_expire_htlcs

Adds unittest to test the dont_expire_htlcs logic
Pull Request #10230: lightning: refactor htlc switch

486 of 603 new or added lines in 9 files covered. (80.6%)

26 existing lines in 7 files now uncovered.

23583 of 37838 relevant lines covered (62.33%)

0.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.98
/electrum/lnworker.py
1
# Copyright (C) 2018 The Electrum developers
2
# Distributed under the MIT software license, see the accompanying
3
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4

5
import asyncio
1✔
6
import os
1✔
7
from decimal import Decimal
1✔
8
import random
1✔
9
import time
1✔
10
from enum import IntEnum
1✔
11
from typing import (
1✔
12
    Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator,
13
    Callable, Awaitable, Union,
14
)
15
from types import MappingProxyType
1✔
16
import threading
1✔
17
import socket
1✔
18
from functools import partial
1✔
19
from collections import defaultdict
1✔
20
import concurrent
1✔
21
from concurrent import futures
1✔
22
import urllib.parse
1✔
23
import itertools
1✔
24
import dataclasses
1✔
25

26
import aiohttp
1✔
27
import dns.asyncresolver
1✔
28
import dns.exception
1✔
29
from aiorpcx import run_in_thread, NetAddress, ignore_after
1✔
30

31
from .logging import Logger
1✔
32
from .i18n import _
1✔
33
from .json_db import stored_in
1✔
34
from .channel_db import UpdateStatus, ChannelDBNotLoaded, get_mychannel_info, get_mychannel_policy
1✔
35

36
from . import constants, util, lnutil
1✔
37
from .util import (
1✔
38
    profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
39
    event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
40
    make_aiohttp_session, random_shuffled_copy, is_private_netaddress,
41
    UnrelatedTransactionException, LightningHistoryItem
42
)
43
from .fee_policy import (
1✔
44
    FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
45
    FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET
46
)
47
from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice
1✔
48
from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress
1✔
49
from .bip32 import BIP32Node
1✔
50
from .address_synchronizer import TX_HEIGHT_LOCAL
1✔
51
from .transaction import (
1✔
52
    Transaction, get_script_type_from_output_script, PartialTxOutput, PartialTransaction, PartialTxInput
53
)
54
from .crypto import (
1✔
55
    sha256, chacha20_encrypt, chacha20_decrypt, pw_encode_with_version_and_mac, pw_decode_with_version_and_mac
56
)
57

58
from .onion_message import OnionMessageManager
1✔
59
from .lntransport import (
1✔
60
    LNTransport, LNResponderTransport, LNTransportBase, LNPeerAddr, split_host_port, extract_nodeid,
61
    ConnStringFormatError
62
)
63
from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
1✔
64
from .lnaddr import lnencode, LnAddr, lndecode
1✔
65
from .lnchannel import Channel, AbstractChannel, ChannelState, PeerState, HTLCWithStatus, ChannelBackup
1✔
66
from .lnrater import LNRater
1✔
67
from .lnutil import (
1✔
68
    get_compressed_pubkey_from_bech32, serialize_htlc_key, deserialize_htlc_key, PaymentFailure, generate_keypair,
69
    LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_ACCEPTED, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
70
    ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
71
    OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
72
    NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT,
73
    MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, RecvMPPResolution, ReceivedMPPStatus, ReceivedMPPHtlc,
74
    PaymentSuccess,
75
)
76
from .lnonion import (
1✔
77
    decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket,
78
    ProcessedOnionPacket, calc_hops_data_for_payment, new_onion_packet,
79
)
80
from .lnmsg import decode_msg
1✔
81
from .lnrouter import (
1✔
82
    RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_within_budget, NoChannelPolicy,
83
    LNPathInconsistent, fee_for_edge_msat,
84
)
85
from .lnwatcher import LNWatcher
1✔
86
from .submarine_swaps import SwapManager
1✔
87
from .mpp_split import suggest_splits, SplitConfigRating
1✔
88
from .trampoline import (
1✔
89
    create_trampoline_route_and_onion, is_legacy_relay, trampolines_by_id, hardcoded_trampoline_nodes,
90
    is_hardcoded_trampoline, decode_routing_info
91
)
92

93
if TYPE_CHECKING:
    from .network import Network
    from .wallet import Abstract_Wallet
    from .channel_db import ChannelDB
    from .simple_config import SimpleConfig


# Payment statuses that are persisted to disk (in-flight/failed are transient).
SAVED_PR_STATUS = [PR_PAID, PR_UNPAID]  # status that are persisted

# How many connected peers _maintain_connectivity aims for before it stops dialing.
NUM_PEERS_TARGET = 4

# onchain channel backup data
CB_VERSION = 0
# 4-byte version prefix marking onchain channel backup payloads.
CB_MAGIC_BYTES = bytes([0, 0, 0, CB_VERSION])
# Length in bytes of the truncated node-id stored in onchain channel
# backups.  NOTE(review): used elsewhere in this file — confirm at call sites.
NODE_ID_PREFIX_LEN = 16
108

109

110
class PaymentDirection(IntEnum):
    """Role a lightning payment plays from this wallet's point of view."""
    SENT = 0          # outgoing payment we initiated
    RECEIVED = 1      # incoming payment to one of our invoices
    SELF_PAYMENT = 2  # payment from this wallet to itself
    FORWARDING = 3    # htlc routed through us for third parties
115

116

117
@dataclasses.dataclass(frozen=True, kw_only=True)
1✔
118
class PaymentInfo:
1✔
119
    """Information required to handle incoming htlcs for a payment request"""
120
    payment_hash: bytes
1✔
121
    amount_msat: Optional[int]
1✔
122
    direction: int
1✔
123
    status: int
1✔
124
    min_final_cltv_delta: int
1✔
125
    expiry_delay: int
1✔
126
    creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time()))
1✔
127

128
    @property
1✔
129
    def expiration_ts(self):
1✔
130
        return self.creation_ts + self.expiry_delay
1✔
131

132
    def validate(self):
1✔
133
        assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
1✔
134
        assert self.amount_msat is None or isinstance(self.amount_msat, int)
1✔
135
        assert isinstance(self.direction, int)
1✔
136
        assert isinstance(self.status, int)
1✔
137
        assert isinstance(self.min_final_cltv_delta, int)
1✔
138
        assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0
1✔
139
        assert isinstance(self.creation_ts, int)
1✔
140

141
    def __post_init__(self):
1✔
142
        self.validate()
1✔
143

144

145

146
# Key identifying an htlc we sent: (RHASH, scid, htlc_id).
SentHtlcKey = Tuple[bytes, ShortChannelID, int]


class SentHtlcInfo(NamedTuple):
    """Bookkeeping for a single htlc sent as part of a (possibly multi-part) payment."""
    route: LNPaymentRoute
    payment_secret_orig: bytes
    payment_secret_bucket: bytes
    amount_msat: int
    bucket_msat: int
    amount_receiver_msat: int
    trampoline_fee_level: Optional[int]
    trampoline_route: Optional[LNPaymentRoute]
158

159

160
class ErrorAddingPeer(Exception):
    """Raised when a connection to a new lightning peer cannot be added."""
161

162

163
# Baseline feature flags shared by both LNWallet and LNGossip.
# Note that e.g. DATA_LOSS_PROTECT is needed even for LNGossip, as many peers require it.
BASE_FEATURES = (
    LnFeatures(0)
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
    | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
    | LnFeatures.VAR_ONION_OPT
    | LnFeatures.PAYMENT_SECRET_OPT
    | LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
)

# Features advertised by wallet nodes. GOSSIP_QUERIES is deliberately absent:
# we do not want to receive unrequested gossip (see lnpeer.maybe_save_remote_update).
LNWALLET_FEATURES = (
    BASE_FEATURES
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
    | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
    | LnFeatures.VAR_ONION_REQ
    | LnFeatures.PAYMENT_SECRET_REQ
    | LnFeatures.BASIC_MPP_OPT
    | LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
    | LnFeatures.OPTION_SHUTDOWN_ANYSEGWIT_OPT
    | LnFeatures.OPTION_CHANNEL_TYPE_OPT
    | LnFeatures.OPTION_SCID_ALIAS_OPT
    | LnFeatures.OPTION_SUPPORT_LARGE_CHANNEL_OPT
)

# Features advertised by the gossip-fetching node. LNGossip doesn't serve
# gossip itself, but weirdly has to signal GOSSIP_QUERIES so that peers
# satisfy our queries.
LNGOSSIP_FEATURES = (
    BASE_FEATURES
    | LnFeatures.GOSSIP_QUERIES_REQ
    | LnFeatures.GOSSIP_QUERIES_OPT
)
196

197

198
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
    """Base class for a lightning p2p node: owns the peer connections, the
    optional listening server, and retry/backoff bookkeeping for dialing peers
    (via NetworkRetryManager). Subclassed by e.g. LNGossip below.
    """

    def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
        Logger.__init__(self)
        NetworkRetryManager.__init__(
            self,
            max_retry_delay_normal=3600,
            init_retry_delay_normal=600,
            max_retry_delay_urgent=300,
            init_retry_delay_urgent=4,
        )
        self.lock = threading.RLock()
        self.node_keypair = node_keypair
        self._peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer  # needs self.lock
        self._channelless_incoming_peers = set()  # type: Set[bytes]  # node_ids  # needs self.lock
        self.taskgroup = OldTaskGroup()
        self.listen_server = None  # type: Optional[asyncio.AbstractServer]
        self.features = features
        self.network = None  # type: Optional[Network]  # set later by start_network()
        self.config = config
        self.stopping_soon = False  # whether we are being shut down
        self.register_callbacks()

    @property
    def channel_db(self) -> Optional['ChannelDB']:
        """The network's gossip database, or None (no network yet, or running without gossip)."""
        return self.network.channel_db if self.network else None

    def uses_trampoline(self) -> bool:
        """True iff we have no channel_db, i.e. we route via trampoline instead of gossip."""
        return not bool(self.channel_db)

    @property
    def peers(self) -> Mapping[bytes, Peer]:
        """Returns a read-only copy of peers."""
        with self.lock:
            return self._peers.copy()

    def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
        """Channels we have with the given peer. Base class has none; overridden in subclasses."""
        return {}

    def get_node_alias(self, node_id: bytes) -> Optional[str]:
        """Returns the alias of the node, or None if unknown."""
        node_alias = None
        if not self.uses_trampoline():
            node_info = self.channel_db.get_node_info_for_node_id(node_id)
            if node_info:
                node_alias = node_info.alias
        else:
            # without gossip, the only names we know are the hardcoded trampolines
            for k, v in hardcoded_trampoline_nodes().items():
                if v.pubkey.startswith(node_id):
                    node_alias = k
                    break
        return node_alias

    async def maybe_listen(self) -> None:
        """Start a server for incoming lightning p2p connections, if LIGHTNING_LISTEN is configured."""
        # FIXME: only one LNWorker can listen at a time (single port)
        listen_addr = self.config.LIGHTNING_LISTEN
        if listen_addr:
            self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
            try:
                netaddr = NetAddress.from_string(listen_addr)
            except Exception as e:
                self.logger.error(f"failed to parse config key '{self.config.cv.LIGHTNING_LISTEN.key()}'. got: {e!r}")
                return
            addr = str(netaddr.host)

            async def cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
                # per-connection callback: run the noise handshake, then register the peer
                transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
                try:
                    node_id = await transport.handshake()
                except Exception as e:
                    self.logger.info(f'handshake failure from incoming connection: {e!r}')
                    return
                peername = writer.get_extra_info('peername')
                self.logger.debug(f"handshake done for incoming peer: {peername=}, node_id={node_id.hex()}")
                await self._add_peer_from_transport(node_id=node_id, transport=transport)
            try:
                self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
            except OSError as e:
                self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")

    async def main_loop(self) -> None:
        """Run this worker's taskgroup forever (until cancelled from stop())."""
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
        except Exception as e:
            self.logger.exception("taskgroup died.")
        finally:
            self.logger.info("taskgroup stopped.")

    async def _maintain_connectivity(self) -> None:
        """Background loop: keep dialing candidate peers until NUM_PEERS_TARGET are connected."""
        while True:
            await asyncio.sleep(1)
            if self.stopping_soon:
                return
            now = time.time()
            if len(self._peers) >= NUM_PEERS_TARGET:
                continue
            peers = await self._get_next_peers_to_try()
            for peer in peers:
                if self._can_retry_addr(peer, now=now):
                    try:
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
                    except ErrorAddingPeer as e:
                        self.logger.info(f"failed to add peer: {peer}. exc: {e!r}")

    async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
        """Connect to (host, port, node_id) and return the Peer; returns the existing
        Peer if we are already connected. Raises ErrorAddingPeer when the target is
        one of our own nodes."""
        if node_id in self._peers:
            return self._peers[node_id]
        port = int(port)
        peer_addr = LNPeerAddr(host, port, node_id)
        self._trying_addr_now(peer_addr)  # record attempt for retry backoff
        self.logger.info(f"adding peer {peer_addr}")
        if node_id == self.node_keypair.pubkey or self.is_our_lnwallet(node_id):
            raise ErrorAddingPeer("cannot connect to self")
        transport = LNTransport(self.node_keypair.privkey, peer_addr,
                                e_proxy=ESocksProxy.from_network_settings(self.network))
        peer = await self._add_peer_from_transport(node_id=node_id, transport=transport)
        assert peer
        return peer

    async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransportBase) -> Optional[Peer]:
        """Wrap an established transport in a Peer, register it, and spawn its main loop.
        Returns None when the connection is rejected (duplicate initialized peer, or
        the channel-less-incoming-peer limit is hit)."""
        with self.lock:
            existing_peer = self._peers.get(node_id)
            if existing_peer:
                # Two instances of the same wallet are attempting to connect simultaneously.
                # If we let the new connection replace the existing one, the two instances might
                # both keep trying to reconnect, resulting in neither being usable.
                if existing_peer.is_initialized():
                    # give priority to the existing connection
                    transport.close()
                    return None
                else:
                    # Use the new connection. (e.g. old peer might be an outgoing connection
                    # for an outdated host/port that will never connect)
                    existing_peer.close_and_cleanup()
            # limit max number of incoming channel-less peers.
            # what to do if limit is reached?
            # - chosen strategy: we don't allow new connections.
            #   - drawback: attacker can use up all our slots
            # - alternative: kick oldest channel-less peer
            #   - drawback: if many legit peers want to connect to us, we will keep kicking them
            #               in round-robin, and they will keep reconnecting. no stable state -> we self-DOS
            # TODO make slots IP-based?
            if isinstance(transport, LNResponderTransport):
                assert node_id not in self._channelless_incoming_peers
                chans = [chan for chan in self.channels_for_peer(node_id).values() if chan.is_funded()]
                if not chans:
                    if len(self._channelless_incoming_peers) > 100:
                        transport.close()
                        return None
                    self._channelless_incoming_peers.add(node_id)
            # checks done: we are adding this peer.
            peer = Peer(self, node_id, transport)
            assert node_id not in self._peers
            self._peers[node_id] = peer
        await self.taskgroup.spawn(peer.main_loop())
        return peer

    def peer_closed(self, peer: Peer) -> None:
        """Forget the given peer; no-op on the registry if it was already replaced
        by a newer Peer object for the same pubkey."""
        with self.lock:
            peer2 = self._peers.get(peer.pubkey)
            if peer2 is peer:
                self._peers.pop(peer.pubkey)
            self._channelless_incoming_peers.discard(peer.pubkey)

    def num_peers(self) -> int:
        """Number of peers that have completed initialization."""
        return sum([p.is_initialized() for p in self.peers.values()])

    def is_our_lnwallet(self, node_id: bytes) -> bool:
        """Check if node_id is one of our own wallets"""
        wallets = self.network.daemon.get_wallets()
        for wallet in wallets.values():
            if wallet.lnworker and wallet.lnworker.node_keypair.pubkey == node_id:
                return True
        return False

    def start_network(self, network: 'Network') -> None:
        """Attach to the Network and start the main loop on its event loop. Must be called once."""
        assert network
        assert self.network is None, "already started"
        self.network = network
        self._add_peers_from_config()
        asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)

    async def stop(self) -> None:
        """Shut down: stop listening, detach callbacks, and cancel all spawned tasks."""
        if self.listen_server:
            self.listen_server.close()
        self.unregister_callbacks()
        await self.taskgroup.cancel_remaining()

    def _add_peers_from_config(self) -> None:
        """Schedule connections to the statically configured LIGHTNING_PEERS."""
        peer_list = self.config.LIGHTNING_PEERS or []
        for host, port, pubkey in peer_list:
            asyncio.run_coroutine_threadsafe(
                self._add_peer(host, int(port), bfh(pubkey)),
                self.network.asyncio_loop)

    def is_good_peer(self, peer: LNPeerAddr) -> bool:
        # the purpose of this method is to filter peers that advertise the desired feature bits
        # it is disabled for now, because feature bits published in node announcements seem to be unreliable
        return True
        # NOTE: everything below is intentionally unreachable while the check is disabled.
        node_id = peer.pubkey
        node = self.channel_db._nodes.get(node_id)
        if not node:
            return False
        try:
            ln_compare_features(self.features, node.features)
        except IncompatibleLightningFeatures:
            return False
        #self.logger.info(f'is_good {peer.host}')
        return True

    def on_peer_successfully_established(self, peer: Peer) -> None:
        """Bookkeeping after a peer finished its handshake/init (outgoing transports only)."""
        if isinstance(peer.transport, LNTransport):
            peer_addr = peer.transport.peer_addr
            # reset connection attempt count
            self._on_connection_successfully_established(peer_addr)
            if not self.uses_trampoline():
                # add into channel db
                self.channel_db.add_recent_peer(peer_addr)
            # save network address into channels we might have with peer
            for chan in peer.channels.values():
                chan.add_or_update_peer_addr(peer_addr)

    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
        """Return candidate peers to dial, in decreasing order of preference:
        recent peers, then random nodes from the gossip graph, then the hardcoded
        fallback list, and finally DNS seeds (BOLT-10)."""
        now = time.time()
        await self.channel_db.data_loaded.wait()
        # first try from recent peers
        recent_peers = self.channel_db.get_recent_peers()
        for peer in recent_peers:
            if not peer:
                continue
            if peer.pubkey in self._peers:
                continue
            if not self._can_retry_addr(peer, now=now):
                continue
            if not self.is_good_peer(peer):
                continue
            return [peer]
        # try random peer from graph
        unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
        if unconnected_nodes:
            for node_id in unconnected_nodes:
                addrs = self.channel_db.get_node_addresses(node_id)
                if not addrs:
                    continue
                host, port, timestamp = self.choose_preferred_address(list(addrs))
                try:
                    peer = LNPeerAddr(host, port, node_id)
                except ValueError:
                    continue
                if not self._can_retry_addr(peer, now=now):
                    continue
                if not self.is_good_peer(peer):
                    continue
                #self.logger.info('taking random ln peer from our channel db')
                return [peer]

        # getting desperate... let's try hardcoded fallback list of peers
        fallback_list = constants.net.FALLBACK_LN_NODES
        fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
        if fallback_list:
            return [random.choice(fallback_list)]

        # last resort: try dns seeds (BOLT-10)
        return await self._get_peers_from_dns_seeds()

    async def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
        """Query a random BOLT-10 DNS seed and return up to 2*NUM_PEERS_TARGET peers."""
        # Return several peers to reduce the number of dns queries.
        if not constants.net.LN_DNS_SEEDS:
            return []
        dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
        self.logger.info('asking dns seed "{}" for ln peers'.format(dns_seed))
        try:
            # note: this might block for several seconds
            # this will include bech32-encoded-pubkeys and ports
            srv_answers = await resolve_dns_srv('r{}.{}'.format(
                constants.net.LN_REALM_BYTE, dns_seed))
        except dns.exception.DNSException as e:
            self.logger.info(f'failed querying (1) dns seed "{dns_seed}" for ln peers: {repr(e)}')
            return []
        random.shuffle(srv_answers)
        num_peers = 2 * NUM_PEERS_TARGET
        srv_answers = srv_answers[:num_peers]
        # we now have pubkeys and ports but host is still needed
        peers = []
        for srv_ans in srv_answers:
            try:
                # note: this might take several seconds
                answers = await dns.asyncresolver.resolve(srv_ans['host'])
            except dns.exception.DNSException as e:
                self.logger.info(f'failed querying (2) dns seed "{dns_seed}" for ln peers: {repr(e)}')
                continue
            try:
                ln_host = str(answers[0])
                port = int(srv_ans['port'])
                # the node pubkey is bech32-encoded in the first hostname label
                bech32_pubkey = srv_ans['host'].split('.')[0]
                pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
                peers.append(LNPeerAddr(ln_host, port, pubkey))
            except Exception as e:
                self.logger.info(f'error with parsing peer from dns seed: {repr(e)}')
                continue
        self.logger.info(f'got {len(peers)} ln peers from dns seed')
        return peers

    @staticmethod
    def choose_preferred_address(addr_list: Sequence[Tuple[str, int, int]]) -> Tuple[str, int, int]:
        """From (host, port, timestamp) triples, pick the most recent plain-IP
        address, falling back to a random entry."""
        assert len(addr_list) >= 1
        # choose the most recent one that is an IP
        for host, port, timestamp in sorted(addr_list, key=lambda a: -a[2]):
            if is_ip_address(host):
                return host, port, timestamp
        # otherwise choose one at random
        # TODO maybe filter out onion if not on tor?
        choice = random.choice(addr_list)
        return choice

    @event_listener
    def on_event_proxy_set(self, *args) -> None:
        """On proxy change: drop all peer connections and reset retry backoff timers."""
        for peer in self.peers.values():
            peer.close_and_cleanup()
        self._clear_addr_retry_times()

    @log_exceptions
    async def add_peer(self, connect_str: str) -> Peer:
        """Connect to a peer given a connection string (node pubkey, optionally
        with host:port). Falls back to trampoline list / channel_db for the
        address. Raises ConnStringFormatError if no usable address is found."""
        node_id, rest = extract_nodeid(connect_str)
        peer = self._peers.get(node_id)
        if not peer:
            if rest is not None:
                host, port = split_host_port(rest)
            else:
                if self.uses_trampoline():
                    addr = trampolines_by_id().get(node_id)
                    if not addr:
                        raise ConnStringFormatError(_('Address unknown for node:') + ' ' + node_id.hex())
                    host, port = addr.host, addr.port
                else:
                    addrs = self.channel_db.get_node_addresses(node_id)
                    if not addrs:
                        raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + node_id.hex())
                    host, port, timestamp = self.choose_preferred_address(list(addrs))
            port = int(port)

            if not self.network.proxy:
                # Try DNS-resolving the host (if needed). This is simply so that
                # the caller gets a nice exception if it cannot be resolved.
                # (we don't do the DNS lookup if a proxy is set, to avoid a DNS-leak)
                if host.endswith('.onion'):
                    raise ConnStringFormatError(_('.onion address, but no proxy configured'))
                try:
                    await asyncio.get_running_loop().getaddrinfo(host, port)
                except socket.gaierror:
                    raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))

            # add peer
            peer = await self._add_peer(host, port, node_id)
        return peer
555

556

557
class LNGossip(LNWorker):
    """The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
    gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the
    LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
    peer queries without fetching gossip themselves. This separation is done so that gossip can be queried
    independently of the active LNWallets. LNGossip keeps a curated batch of gossip in _forwarding_gossip
    that is fetched by the LNWallets for regular forwarding."""
    # gossip older than this is pruned from the db (14 days, in seconds)
    max_age = 14*24*3600

    def __init__(self, config: 'SimpleConfig'):
        # throwaway random node identity: this node is never announced
        seed = os.urandom(32)
        node = BIP32Node.from_rootseed(seed, xtype='standard')
        xprv = node.to_xprv()
        node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
        LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
        self.unknown_ids = set()  # short channel ids we heard of but have no announcement for
        self._forwarding_gossip = []  # type: List[GossipForwardingMessage]
        self._last_gossip_batch_ts = 0  # type: int
        self._forwarding_gossip_lock = asyncio.Lock()
        # limits concurrent gossip queries (presumably consumed by peers; confirm in lnpeer)
        self.gossip_request_semaphore = asyncio.Semaphore(5)
        # statistics
        self._num_chan_ann = 0
        self._num_node_ann = 0
        self._num_chan_upd = 0
        self._num_chan_upd_good = 0

    def start_network(self, network: 'Network'):
        """Start the base worker plus the gossip maintenance loops."""
        super().start_network(network)
        for coro in [
                self._maintain_connectivity(),
                self.maintain_db(),
                self._maintain_forwarding_gossip()
        ]:
            tg_coro = self.taskgroup.spawn(coro)
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)

    async def maintain_db(self):
        """Background loop: prune old policies and orphaned channels every 2 minutes,
        but only while we have no outstanding unknown channel ids (i.e. not mid-sync)."""
        await self.channel_db.data_loaded.wait()
        while True:
            if len(self.unknown_ids) == 0:
                self.channel_db.prune_old_policies(self.max_age)
                self.channel_db.prune_orphaned_channels()
            await asyncio.sleep(120)

    async def _maintain_forwarding_gossip(self):
        """Background loop: once synced, refresh the batch of gossip messages that
        LNWallets fetch for forwarding, once per minute."""
        await self.channel_db.data_loaded.wait()
        await self.wait_for_sync()
        while True:
            async with self._forwarding_gossip_lock:
                self._forwarding_gossip = self.channel_db.get_forwarding_gossip_batch()
                self._last_gossip_batch_ts = int(time.time())
            self.logger.debug(f"{len(self._forwarding_gossip)} gossip messages available to forward")
            await asyncio.sleep(60)

    async def get_forwarding_gossip(self) -> tuple[List[GossipForwardingMessage], int]:
        """Return the current forwarding-gossip batch and the timestamp it was built."""
        async with self._forwarding_gossip_lock:
            return self._forwarding_gossip, self._last_gossip_batch_ts

    async def add_new_ids(self, ids: Iterable[bytes]):
        """Record channel ids we have heard of but whose announcements we still lack."""
        known = self.channel_db.get_channel_ids()
        new = set(ids) - set(known)
        self.unknown_ids.update(new)
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
        util.trigger_callback('gossip_peers', self.num_peers())
        util.trigger_callback('ln_gossip_sync_progress')

    def get_ids_to_query(self) -> Sequence[bytes]:
        """Pop up to 500 unknown channel ids to be queried from peers."""
        N = 500
        l = list(self.unknown_ids)
        self.unknown_ids = set(l[N:])
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
        util.trigger_callback('ln_gossip_sync_progress')
        return l[0:N]

    def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """Estimates the gossip synchronization process and returns the number
        of synchronized channels, the total channels in the network and a
        rescaled percentage of the synchronization process."""
        if self.num_peers() == 0:
            return None, None, None
        nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
        num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
        num_nodes = self.channel_db.num_nodes
        num_nodes_associated_to_chans = max(len(self.channel_db._channels_for_node.keys()), 1)
        # some channels will never have two policies (only one is in gossip?...)
        # so if we have at least 1 policy for a channel, we consider that channel "complete" here
        current_est = num_db_channels - nchans_with_0p
        total_est = len(self.unknown_ids) + num_db_channels

        progress_chans = current_est / total_est if total_est and current_est else 0
        # consider that we got at least 10% of the node anns of node ids we know about
        progress_nodes = min((num_nodes / num_nodes_associated_to_chans) * 10, 1)
        progress = (progress_chans * 3 + progress_nodes) / 4  # weigh the channel progress higher
        # self.logger.debug(f"Sync process chans: {progress_chans} | Progress nodes: {progress_nodes} | "
        #                   f"Total progress: {progress} | NUM_NODES: {num_nodes} / {num_nodes_associated_to_chans}")
        progress_percent = (1.0 / 0.95 * progress) * 100
        progress_percent = min(progress_percent, 100)
        progress_percent = round(progress_percent)
        # take a minimal number of synchronized channels to get a more accurate
        # percentage estimate
        if current_est < 200:
            progress_percent = 0
        return current_est, total_est, progress_percent

    async def process_gossip(self, chan_anns, node_anns, chan_upds):
        """Verify and store a batch of gossip messages, offloading the CPU-heavy
        verification to a thread. Unknown channel ids found in updates are queued
        for later querying."""
        # note: we run in the originating peer's TaskGroup, so we can safely raise here
        #       and disconnect only from that peer
        await self.channel_db.data_loaded.wait()

        # channel announcements
        def process_chan_anns():
            for payload in chan_anns:
                self.channel_db.verify_channel_announcement(payload)
            self.channel_db.add_channel_announcements(chan_anns)
        await run_in_thread(process_chan_anns)

        # node announcements
        def process_node_anns():
            for payload in node_anns:
                self.channel_db.verify_node_announcement(payload)
            self.channel_db.add_node_announcements(node_anns)
        await run_in_thread(process_node_anns)
        # channel updates
        categorized_chan_upds = await run_in_thread(partial(
            self.channel_db.add_channel_updates,
            chan_upds,
            max_age=self.max_age))
        orphaned = categorized_chan_upds.orphaned
        if orphaned:
            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
            orphaned_ids = [c['short_channel_id'] for c in orphaned]
            await self.add_new_ids(orphaned_ids)

        self._num_chan_ann += len(chan_anns)
        self._num_node_ann += len(node_anns)
        self._num_chan_upd += len(chan_upds)
        self._num_chan_upd_good += len(categorized_chan_upds.good)

    def is_synced(self) -> bool:
        """True when the sync progress estimate reports (at least) 100%."""
        _, _, percentage_synced = self.get_sync_progress_estimate()
        if percentage_synced is not None and percentage_synced >= 100:
            return True
        return False

    async def wait_for_sync(self, times_to_check: int = 3):
        """Check if we have 100% sync progress `times_to_check` times in a row (because the
        estimate often jumps back after some seconds when doing initial sync)."""
        while True:
            if self.is_synced():
                times_to_check -= 1
                if times_to_check <= 0:
                    return
            await asyncio.sleep(10)
            # flush the gossip queue so we don't forward old gossip after sync is complete
            self.channel_db.get_forwarding_gossip_batch()
712

713

714
class PaySession(Logger):
    """Mutable in-flight state for one outgoing payment, possibly split over many HTLCs (MPP).

    Tracks the amount and number of HTLCs currently in flight, trampoline fee
    escalation state, and a queue over which resolved HTLC outcomes (HtlcLog)
    are reported back to the sending loop.
    """

    # how long we wait for another htlc to resolve after receiving a failure for one sent htlc.
    TIMEOUT_WAIT_FOR_NEXT_RESOLVED_HTLC = 0.5

    def __init__(
            self,
            *,
            payment_hash: bytes,
            payment_secret: bytes,
            initial_trampoline_fee_level: int,
            invoice_features: int,
            r_tags,  # routing hints (presumably from the invoice) — passed through to route-finding
            min_final_cltv_delta: int,  # delta for last node (typically from invoice)
            amount_to_pay: int,  # total payment amount final receiver will get
            invoice_pubkey: bytes,
            uses_trampoline: bool,  # whether sender uses trampoline or gossip
            use_two_trampolines: bool,  # whether legacy payments will try to use two trampolines
    ):
        assert payment_hash
        assert payment_secret
        self.payment_hash = payment_hash
        self.payment_secret = payment_secret
        # identifies this payment attempt: hash + secret concatenated
        self.payment_key = payment_hash + payment_secret
        Logger.__init__(self)

        self.invoice_features = LnFeatures(invoice_features)
        self.r_tags = r_tags
        self.min_final_cltv_delta = min_final_cltv_delta
        self.amount_to_pay = amount_to_pay
        self.invoice_pubkey = invoice_pubkey

        # resolved htlc outcomes are pushed here and consumed by wait_for_one_htlc_to_resolve()
        self.sent_htlcs_q = asyncio.Queue()  # type: asyncio.Queue[HtlcLog]
        self.start_time = time.time()

        self.uses_trampoline = uses_trampoline
        self.trampoline_fee_level = initial_trampoline_fee_level
        # routes (as lists of hex node ids) that already failed at the current fee level
        self.failed_trampoline_routes = []
        self.use_two_trampolines = use_two_trampolines
        self._sent_buckets = dict()  # psecret_bucket -> (amount_sent, amount_failed)

        self._amount_inflight = 0  # what we sent in htlcs (that receiver gets, without fees)
        self._nhtlcs_inflight = 0
        self.is_active = True  # is still trying to send new htlcs?

    def diagnostic_name(self):
        # short stable log identifier: truncated payment_hash plus hashed payment_key
        pkey = sha256(self.payment_key)
        return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"

    @property
    def number_htlcs_inflight(self) -> int:
        """Number of sent HTLCs that have not resolved yet."""
        return self._nhtlcs_inflight

    def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
        """Bump the trampoline fee level — but only if the failed htlc was sent at the
        *current* level, so several failures of the same round raise it only once."""
        if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
            self.trampoline_fee_level += 1
            # previously failed routes may succeed at the higher fee level
            self.failed_trampoline_routes = []
            self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
        else:
            self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')

    def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
        """React to a failed trampoline htlc: raise fee level, drop the second trampoline,
        blacklist the route, or give up by raising PaymentFailure."""
        # FIXME The trampoline nodes in the path are chosen randomly.
        #       Some of the errors might depend on how we have chosen them.
        #       Having more attempts is currently useful in part because of the randomness,
        #       instead we should give feedback to create_routes_for_payment.
        # Sometimes the trampoline node fails to send a payment and returns
        # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
        if failure_msg.code in (
                OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
                OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
                OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
            # TODO: parse the node policy here (not returned by eclair yet)
            # TODO: erring node is always the first trampoline even if second
            #  trampoline demands more fees, we can't influence this
            self.maybe_raise_trampoline_fee(htlc_log)
        elif self.use_two_trampolines:
            # retry with a single trampoline hop before treating the error as fatal
            self.use_two_trampolines = False
        elif failure_msg.code in (
                OnionFailureCode.UNKNOWN_NEXT_PEER,
                OnionFailureCode.TEMPORARY_NODE_FAILURE):
            trampoline_route = htlc_log.route
            r = [hop.end_node.hex() for hop in trampoline_route]
            self.logger.info(f'failed trampoline route: {r}')
            if r not in self.failed_trampoline_routes:
                self.failed_trampoline_routes.append(r)
            else:
                pass  # maybe the route was reused between different MPP parts
        else:
            raise PaymentFailure(failure_msg.code_name())

    async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
        """Block until one in-flight htlc resolves; update the in-flight accounting."""
        self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
        htlc_log = await self.sent_htlcs_q.get()
        self._amount_inflight -= htlc_log.amount_msat
        self._nhtlcs_inflight -= 1
        if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
            raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
        return htlc_log

    def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
        """Account for a newly sent htlc; maintains per-bucket totals for trampoline MPP."""
        self._nhtlcs_inflight += 1
        self._amount_inflight += sent_htlc_info.amount_receiver_msat
        if self._amount_inflight > self.amount_to_pay:  # safety belts
            raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
        shi = sent_htlc_info
        bkey = shi.payment_secret_bucket
        # if we sent MPP to a trampoline, add item to sent_buckets
        if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
            if bkey not in self._sent_buckets:
                self._sent_buckets[bkey] = (0, 0)
            amount_sent, amount_failed = self._sent_buckets[bkey]
            amount_sent += shi.amount_receiver_msat
            self._sent_buckets[bkey] = amount_sent, amount_failed

    def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
        """On htlc failure, return the amount (msat) to treat as failed, or None if the
        htlc belongs to a trampoline MPP bucket that still has other parts in flight."""
        shi = sent_htlc_info
        # check sent_buckets if we use trampoline
        bkey = shi.payment_secret_bucket
        if self.uses_trampoline and bkey in self._sent_buckets:
            amount_sent, amount_failed = self._sent_buckets[bkey]
            amount_failed += shi.amount_receiver_msat
            self._sent_buckets[bkey] = amount_sent, amount_failed
            if amount_sent != amount_failed:
                self.logger.info('bucket still active...')
                return None
            self.logger.info('bucket failed')
            return amount_sent
        # not using trampoline buckets
        return shi.amount_receiver_msat

    def get_outstanding_amount_to_send(self) -> int:
        """Receiver amount (msat) still to be covered by new htlcs."""
        return self.amount_to_pay - self._amount_inflight

    def can_be_deleted(self) -> bool:
        """Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
        if self.is_active:
            return False
        # note: no one is consuming from sent_htlcs_q anymore
        nhtlcs_resolved = self.sent_htlcs_q.qsize()
        assert nhtlcs_resolved <= self._nhtlcs_inflight
        return nhtlcs_resolved == self._nhtlcs_inflight
1✔
856

857

858
class LNWallet(LNWorker):
1✔
859

860
    lnwatcher: Optional['LNWatcher']
1✔
861
    MPP_EXPIRY = 120
1✔
862
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds
1✔
863
    PAYMENT_TIMEOUT = 120
1✔
864
    MPP_SPLIT_PART_FRACTION = 0.2
1✔
865
    MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
1✔
866

867
    def __init__(self, wallet: 'Abstract_Wallet', xprv):
        """Set up per-wallet lightning state: derive all LN keys from `xprv`,
        load channels/backups and payment bookkeeping from the wallet db, and
        construct the watcher/swap/onion-message sub-components.
        """
        self.wallet = wallet
        self.config = wallet.config
        self.db = wallet.db
        # all LN-related keys are derived deterministically from the wallet's xprv
        self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
        self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey
        self.static_payment_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_BASE)
        self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
        self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
        Logger.__init__(self)
        # advertise optional features depending on config
        features = LNWALLET_FEATURES
        if self.config.ENABLE_ANCHOR_CHANNELS:
            features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
        if self.config.ACCEPT_ZEROCONF_CHANNELS:
            features |= LnFeatures.OPTION_ZEROCONF_OPT
        if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
            features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
        if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
            features |= LnFeatures.GOSSIP_QUERIES_OPT  # signal we have gossip to fetch
        LNWorker.__init__(self, self.node_keypair, features, config=self.config)
        self.lnwatcher = LNWatcher(self)
        self.lnrater: LNRater = None  # set in start_network
        # lightning_payments: RHASH -> amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts
        self.payment_info = self.db.get_dict('lightning_payments')  # type: dict[str, Tuple[Optional[int], int, int, int, int, int]]
        self._preimages = self.db.get_dict('lightning_preimages')   # RHASH -> preimage
        self._bolt11_cache = {}
        # note: this sweep_address is only used as fallback; as it might result in address-reuse
        self.logs = defaultdict(list)  # type: Dict[str, List[HtlcLog]]  # key is RHASH  # (not persisted)
        # used in tests
        self.enable_htlc_settle = True
        self.enable_htlc_forwarding = True

        # note: accessing channels (besides simple lookup) needs self.lock!
        self._channels = {}  # type: Dict[bytes, Channel]
        channels = self.db.get_dict("channels")
        for channel_id, c in random_shuffled_copy(channels.items()):
            self._channels[bfh(channel_id)] = chan = Channel(c, lnworker=self)
            self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)

        self._channel_backups = {}  # type: Dict[bytes, ChannelBackup]
        # order is important: imported should overwrite onchain
        for name in ["onchain_channel_backups", "imported_channel_backups"]:
            channel_backups = self.db.get_dict(name)
            for channel_id, storage in channel_backups.items():
                self._channel_backups[bfh(channel_id)] = cb = ChannelBackup(storage, lnworker=self)
                self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)

        self._paysessions = dict()                      # type: Dict[bytes, PaySession]
        self.sent_htlcs_info = dict()                   # type: Dict[SentHtlcKey, SentHtlcInfo]
        self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs')   # type: Dict[str, ReceivedMPPStatus]  # payment_key -> ReceivedMPPStatus

        # detect inflight payments
        self.inflight_payments = set()        # (not persisted) keys of invoices that are in PR_INFLIGHT state
        for payment_hash in self.get_payments(status='inflight').keys():
            self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)

        # payment forwarding
        self.active_forwardings = self.db.get_dict('active_forwardings')    # type: Dict[str, List[str]]        # Dict: payment_key -> list of htlc_keys
        self.forwarding_failures = self.db.get_dict('forwarding_failures')  # type: Dict[str, Tuple[str, str]]  # Dict: payment_key -> (error_bytes, error_message)
        self.downstream_to_upstream_htlc = {}                               # type: Dict[str, str]              # Dict: htlc_key -> htlc_key (not persisted)

        # k: payment_hashes of htlcs that we should not expire even if we don't know the preimage
        # v: If `None` the htlcs won't get expired and potentially get timed out in a force close.
        #    Note: it might not be safe to release the preimage shortly before expiry as this would allow the
        #          remote node to ignore our fulfill_htlc, wait until expiry and try to time out the htlc onchain
        #          in a fee race against us and then use our released preimage to fulfill upstream.
        # v: If `int`: Overwrites `MIN_FINAL_CLTV_DELTA_ACCEPTED` in htlc switch and allows to set custom
        #              expiration delta. The htlcs will get expired if their blocks left to expiry are
        #              below the specified expiration delta.
        # htlcs will get settled as soon as the preimage becomes available
        self.dont_expire_htlcs = self.db.get_dict('dont_expire_htlcs')      # type: Dict[str, Optional[int]]

        # k: payment_hash of payments for which we don't want to release the preimage, no matter
        # how close to expiry. Doesn't prevent htlcs from getting expired or failed if there is no
        # preimage available. Might be used in combination with dont_expire_htlcs.
        self.dont_settle_htlcs = self.db.get_dict('dont_settle_htlcs')  # type: Dict[str, None]

        # payment_hash -> callback:
        self.hold_invoice_callbacks = {}                # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
        self._payment_bundles_pkey_to_canon = {}       # type: Dict[bytes, bytes]            # TODO: persist
        self._payment_bundles_canon_to_pkeylist = {}   # type: Dict[bytes, Sequence[bytes]]  # TODO: persist

        self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
        self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
        self.onion_message_manager = OnionMessageManager(self)
        self.subscribe_to_channels()
1✔
953

954
    def subscribe_to_channels(self):
        """Register every channel and channel backup with the LNWatcher."""
        for chan_like in itertools.chain(self.channels.values(), self.channel_backups.values()):
            self.lnwatcher.add_channel(chan_like)
×
959

960
    def has_deterministic_node_id(self) -> bool:
        """True if our LN node key is derived from the wallet seed ('lightning_xprv' in db)."""
        lightning_xprv = self.db.get('lightning_xprv')
        return bool(lightning_xprv)
×
962

963
    def can_have_recoverable_channels(self) -> bool:
        """Recoverable channels require a seed-derived node id and that we do not
        accept incoming connections (LIGHTNING_LISTEN unset)."""
        if not self.has_deterministic_node_id():
            return False
        return not self.config.LIGHTNING_LISTEN
966

967
    def has_recoverable_channels(self) -> bool:
        """Whether *future* channels opened by this wallet would be recoverable
        from seed (via putting OP_RETURN outputs into funding txs).
        """
        if not self.can_have_recoverable_channels():
            return False
        return self.config.LIGHTNING_USE_RECOVERABLE_CHANNELS
973

974
    def has_anchor_channels(self) -> bool:
        """Returns True if any active (non-closed) channel is an anchor channel."""
        for chan in self.channels.values():
            if chan.has_anchors() and not chan.is_closed():
                return True
        return False
978

979
    @property
    def channels(self) -> Mapping[bytes, Channel]:
        """Thread-safe snapshot copy of the open-channel map (chan_id -> Channel)."""
        with self.lock:
            return dict(self._channels)
1✔
984

985
    @property
    def channel_backups(self) -> Mapping[bytes, ChannelBackup]:
        """Thread-safe snapshot copy of the channel-backup map (chan_id -> ChannelBackup)."""
        with self.lock:
            return dict(self._channel_backups)
1✔
990

991
    def get_channel_objects(self) -> Mapping[bytes, AbstractChannel]:
        """All channel-like objects, with real channels overriding backups on key collision."""
        combined = dict(self.channel_backups)
        combined.update(self.channels)
        return combined
×
995

996
    def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
        """Look up an open channel by id; None if unknown.
        (simple dict lookup — per the class convention, no lock needed)"""
        return self._channels.get(channel_id)
1✔
998

999
    def diagnostic_name(self):
        """Use the underlying wallet's diagnostic name for log lines."""
        wallet_name = self.wallet.diagnostic_name()
        return wallet_name
1✔
1001

1002
    @ignore_exceptions
    @log_exceptions
    async def sync_with_remote_watchtower(self):
        """Long-running task: periodically push revoked-state sweep txs to a
        user-configured remote watchtower over JSON-RPC.

        Re-reads the config every iteration so the user can change the
        watchtower URL at runtime without a restart.
        """
        self.watchtower_ctns = {}  # outpoint -> last ctn synced to the tower
        while True:
            # periodically poll if the user updated 'watchtower_url'
            await asyncio.sleep(5)
            watchtower_url = self.config.WATCHTOWER_CLIENT_URL
            if not watchtower_url:
                continue
            parsed_url = urllib.parse.urlparse(watchtower_url)
            # refuse plaintext transport unless the tower is on a private network
            if not (parsed_url.scheme == 'https' or is_private_netaddress(parsed_url.hostname)):
                self.logger.warning(f"got watchtower URL for remote tower but we won't use it! "
                                    f"can only use HTTPS (except if private IP): not using {watchtower_url!r}")
                continue
            # try to sync with the remote watchtower
            try:
                async with make_aiohttp_session(proxy=self.network.proxy) as session:
                    watchtower = JsonRPCClient(session, watchtower_url)
                    watchtower.add_method('get_ctn')
                    watchtower.add_method('add_sweep_tx')
                    for chan in self.channels.values():
                        await self.sync_channel_with_watchtower(chan, watchtower)
            except aiohttp.client_exceptions.ClientConnectorError:
                # tower unreachable: not fatal, we retry on the next iteration
                self.logger.info(f'could not contact remote watchtower {watchtower_url}')
×
1027

1028
    def get_watchtower_ctn(self, channel_point):
        """Last commitment number synced to the remote watchtower for this
        channel outpoint, or None if nothing was synced yet."""
        synced_ctns = self.watchtower_ctns
        return synced_ctns.get(channel_point)
×
1030

1031
    async def sync_channel_with_watchtower(self, chan: Channel, watchtower):
        """Upload sweep txs for every revoked commitment number the remote tower
        has not seen yet, and record the progress in self.watchtower_ctns.

        `watchtower` is a JsonRPCClient exposing 'get_ctn' and 'add_sweep_tx'.
        """
        outpoint = chan.funding_outpoint.to_str()
        addr = chan.get_funding_address()
        current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
        # ask the tower how far it is already synced for this channel
        watchtower_ctn = await watchtower.get_ctn(outpoint, addr)
        for ctn in range(watchtower_ctn + 1, current_ctn):
            sweeptxs = chan.create_sweeptxs_for_watchtower(ctn)
            for tx in sweeptxs:
                await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
            # only advance the local marker after the tower accepted all txs for this ctn
            self.watchtower_ctns[outpoint] = ctn
×
1041

1042
    def start_network(self, network: 'Network'):
        """Attach to the Network: start the watcher, swap manager and rater, then
        spawn the long-running lightning tasks on our taskgroup."""
        super().start_network(network)
        self.lnwatcher.start_network(network)
        self.swap_manager.start_network(network)
        self.lnrater = LNRater(self, network)
        self.onion_message_manager.start_network(network=network)

        for coro in [
                self.maybe_listen(),
                self.lnwatcher.trigger_callbacks(),  # shortcut (don't block) if funding tx locked and verified
                self.reestablish_peers_and_channels(),
                self.sync_with_remote_watchtower(),
        ]:
            tg_coro = self.taskgroup.spawn(coro)
            # we may be called from a thread other than the network's event loop
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
1✔
1057

1058
    async def stop(self):
        """Graceful shutdown: stop accepting new peers, give removable received
        HTLCs a bounded window to get removed, then stop the worker and its
        sub-components (watcher, swaps, onion messages)."""
        self.stopping_soon = True
        if self.listen_server:  # stop accepting new peers
            self.listen_server.close()
        # bounded wait — never hang shutdown on unresolved htlcs
        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
            await self.wait_for_received_pending_htlcs_to_get_removed()
        await LNWorker.stop(self)
        if self.lnwatcher:
            self.lnwatcher.stop()
            self.lnwatcher = None
        if self.swap_manager and self.swap_manager.network:  # may not be present in tests
            await self.swap_manager.stop()
        if self.onion_message_manager:
            await self.onion_message_manager.stop()
×
1072

1073
    async def wait_for_received_pending_htlcs_to_get_removed(self):
        """During shutdown: fail/fulfill what we can and wait until those received
        HTLCs are irrevocably removed. Only called once stopping_soon is set;
        the caller bounds this wait with a timeout."""
        assert self.stopping_soon is True
        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
        #       to wait a bit for it to become irrevocably removed.
        # Note: we don't wait for *all htlcs* to get removed, only for those
        #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
        async with OldTaskGroup() as group:
            for peer in self.peers.values():
                await group.spawn(peer.wait_one_htlc_switch_iteration())
        while True:
            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
                break
            # block until *any* peer signals an htlc removal, then re-check all peers
            async with OldTaskGroup(wait=any) as group:
                for peer in self.peers.values():
                    await group.spawn(peer.received_htlc_removed_event.wait())
1✔
1089

1090
    def peer_closed(self, peer):
        """Mark every channel with this peer as disconnected, refresh the GUI,
        then run the base-class disconnect handling."""
        affected_channels = self.channels_for_peer(peer.pubkey)
        for chan in affected_channels.values():
            chan.peer_state = PeerState.DISCONNECTED
            util.trigger_callback('channel', self.wallet, chan)
        super().peer_closed(peer)
×
1095

1096
    def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
        """Collect HTLCs across all channels, grouped by payment_hash.
        `status` is passed through to each channel's get_payments filter."""
        grouped = defaultdict(list)
        for chan in self.channels.values():
            per_chan = chan.get_payments(status=status)
            for payment_hash, plist in per_chan.items():
                grouped[payment_hash].extend(plist)
        return grouped
1✔
1103

1104
    def get_payment_value(
            self, info: Optional['PaymentInfo'],
            plist: List[HTLCWithStatus]
    ) -> Tuple[PaymentDirection, int, Optional[int], int]:
        """Classify a payment's htlc list and compute (direction, amount_msat,
        fee_msat, timestamp). fee_msat is included in amount_msat."""
        assert plist
        amount_msat = sum(int(h.direction) * h.htlc.amount_msat for h in plist)
        outgoing_only = all(h.direction == SENT for h in plist)
        incoming_only = all(h.direction == RECEIVED for h in plist)
        if outgoing_only:
            direction = PaymentDirection.SENT
            fee_msat = (- info.amount_msat - amount_msat) if info else None
        elif incoming_only:
            direction = PaymentDirection.RECEIVED
            fee_msat = None
        else:
            # mixed directions: a net-negative amount means we paid ourselves,
            # otherwise we forwarded for someone else; either way the net is the fee
            direction = PaymentDirection.SELF_PAYMENT if amount_msat < 0 else PaymentDirection.FORWARDING
            fee_msat = - amount_msat
        timestamp = min(h.htlc.timestamp for h in plist)
        return direction, amount_msat, fee_msat, timestamp
×
1125

1126
    def get_lightning_history(self) -> Dict[str, LightningHistoryItem]:
        """Build the lightning history: one item per settled payment, plus
        open/close items per channel, keyed by payment_hash hex or txid.

        side effect: sets defaults labels
        note that the result is not ordered
        """
        out = {}
        # settled payments
        for payment_hash, plist in self.get_payments(status='settled').items():
            if len(plist) == 0:
                continue
            key = payment_hash.hex()
            info = self.get_payment_info(payment_hash)
            # note: just after successfully paying an invoice using MPP, amount and fee values might be shifted
            #       temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also
            #       have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than
            #       expected and the fee higher (the inflight htlcs will be effectively counted as fees).
            direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
            label = self.wallet.get_label_for_rhash(key)
            if not label and direction == PaymentDirection.FORWARDING:
                label = _('Forwarding')
            preimage = self.get_preimage(payment_hash).hex()
            group_id = self.swap_manager.get_group_id_for_payment_hash(payment_hash)
            item = LightningHistoryItem(
                type='payment',
                payment_hash=payment_hash.hex(),
                preimage=preimage,
                amount_msat=amount_msat,
                fee_msat=fee_msat,
                group_id=group_id,
                timestamp=timestamp or 0,
                label=label,
                direction=direction,
            )
            out[payment_hash.hex()] = item
        # channel open/close events (also for backups)
        now = int(time.time())
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
            item = chan.get_funding_height()
            if item is None:
                continue
            funding_txid, funding_height, funding_timestamp = item
            label = _('Open channel') + ' ' + chan.get_id_for_log()
            self.wallet.set_default_label(funding_txid, label)
            self.wallet.set_group_label(funding_txid, label)
            item = LightningHistoryItem(
                type='channel_opening',
                label=label,
                group_id=funding_txid,
                timestamp=funding_timestamp or now,
                amount_msat=chan.balance(LOCAL, ctn=0),  # our balance at channel open
                fee_msat=None,
                payment_hash=None,
                preimage=None,
                direction=None,
            )
            out[funding_txid] = item
            item = chan.get_closing_height()
            if item is None:
                continue
            closing_txid, closing_height, closing_timestamp = item
            label = _('Close channel') + ' ' + chan.get_id_for_log()
            self.wallet.set_default_label(closing_txid, label)
            self.wallet.set_group_label(closing_txid, label)
            item = LightningHistoryItem(
                type='channel_closing',
                label=label,
                group_id=closing_txid,
                timestamp=closing_timestamp or now,
                amount_msat=-chan.balance(LOCAL),  # balance leaves the LN side at close
                fee_msat=None,
                payment_hash=None,
                preimage=None,
                direction=None,
            )
            out[closing_txid] = item

        # sanity check
        balance_msat = sum([x.amount_msat for x in out.values()])
        lb = sum(chan.balance(LOCAL) if not chan.is_closed_or_closing() else 0
                 for chan in self.channels.values())
        if balance_msat != lb:
            # this typically happens when a channel is recently force closed
            self.logger.info(f'get_lightning_history: balance mismatch {balance_msat - lb}')
        return out
×
1208

1209
    def get_groups_for_onchain_history(self) -> Dict[str, str]:
        """Group lightning-related on-chain txs (funding, closing, swaps) for history display.

        returns dict: txid -> group_id
        side effect: sets default labels
        """
        groups = {}
        # add funding events
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
            item = chan.get_funding_height()
            if item is None:
                continue
            funding_txid, funding_height, funding_timestamp = item
            groups[funding_txid] = funding_txid  # the funding tx is its own group
            item = chan.get_closing_height()
            if item is None:
                continue
            closing_txid, closing_height, closing_timestamp = item
            groups[closing_txid] = closing_txid

        # submarine-swap related txs: grouping and labels come from the swap manager
        d = self.swap_manager.get_groups_for_onchain_history()
        for txid, v in d.items():
            group_id = v['group_id']
            label = v.get('label')
            group_label = v.get('group_label') or label
            groups[txid] = group_id
            if label:
                self.wallet.set_default_label(txid, label)
            if group_label:
                self.wallet.set_group_label(group_id, group_label)

        return groups
×
1240

1241
    def channel_peers(self) -> List[bytes]:
        """Node ids of peers we have at least one non-closed channel with
        (one entry per channel, so duplicates are possible)."""
        return [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
×
1244

1245
    def channels_for_peer(self, node_id):
        """Return {chan_id: chan} restricted to channels whose remote node is `node_id`."""
        assert type(node_id) is bytes
        matching = {}
        for chan_id, chan in self.channels.items():
            if chan.node_id == node_id:
                matching[chan_id] = chan
        return matching
1249

1250
    def channel_state_changed(self, chan: Channel):
        """React to a channel state transition: persist the channel, clear the
        invoice cache, clean up received-MPP bookkeeping once the channel is
        fully redeemed, and notify the GUI."""
        if type(chan) is Channel:  # exclude ChannelBackup, which is not saved this way
            self.save_channel(chan)
        self.clear_invoices_cache()
        if chan._state == ChannelState.REDEEMED:
            # channel is gone for good; drop stale MPP state referencing it
            self.maybe_cleanup_mpp(chan)
        util.trigger_callback('channel', self.wallet, chan)
×
1257

1258
    def save_channel(self, chan: Channel):
        """Persist `chan` to the wallet db and notify the GUI.
        Sanity-checks the remote revocation state before saving."""
        assert type(chan) is Channel
        remote_cfg = chan.config[REMOTE]
        if remote_cfg.next_per_commitment_point == remote_cfg.current_per_commitment_point:
            raise Exception("Tried to save channel with next_point == current_point, this should not happen")
        self.wallet.save_db()
        util.trigger_callback('channel', self.wallet, chan)
×
1264

1265
    def channel_by_txo(self, txo: str) -> Optional[AbstractChannel]:
        """Return the channel (or channel backup) whose funding outpoint string
        equals `txo`; channels are searched before backups. None if no match."""
        for chan_map in (self.channels, self.channel_backups):
            for chan in chan_map.values():
                if chan.funding_outpoint.to_str() == txo:
                    return chan
        return None
×
1273

1274
    async def handle_onchain_state(self, chan: Channel):
        """React to the on-chain state of `chan`: schedule a force-close when
        htlcs are about to expire, send channel_ready / fee updates / announcement
        signatures when appropriate, and rebroadcast our own force-close tx if it
        is still local-only."""
        if self.network is None:
            # network not started yet
            return

        if type(chan) is ChannelBackup:
            # a backup cannot act on-chain; just refresh the GUI
            util.trigger_callback('channel', self.wallet, chan)
            return

        if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
                and chan.should_be_closed_due_to_expiring_htlcs(self.wallet.adb.get_local_height())):
            self.logger.info(f"force-closing due to expiring htlcs")
            await self.schedule_force_closing(chan.channel_id)

        elif chan.get_state() == ChannelState.FUNDED:
            peer = self._peers.get(chan.node_id)
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
                peer.send_channel_ready(chan)

        elif chan.get_state() == ChannelState.OPEN:
            peer = self._peers.get(chan.node_id)
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
                peer.maybe_update_fee(chan)
                peer.maybe_send_announcement_signatures(chan)

        elif chan.get_state() == ChannelState.FORCE_CLOSING:
            force_close_tx = chan.force_close_tx()
            txid = force_close_tx.txid()
            height = self.lnwatcher.adb.get_tx_height(txid).height()
            if height == TX_HEIGHT_LOCAL:
                # our closing tx has not been seen on-chain yet; try again
                self.logger.info('REBROADCASTING CLOSING TX')
                await self.network.try_broadcasting(force_close_tx, 'force-close')
×
1306

1307
    def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
        """Return the connected peer whose static JIT scid alias equals *scid_alias*, if any."""
        for node_id, peer in self.peers.items():
            if self._scid_alias_of_node(node_id) == scid_alias:
                return peer
        return None
×
1311

1312
    def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
        """Derive the deterministic 8-byte scid alias for just-in-time channels to *nodeid*."""
        digest = sha256(b'Electrum' + nodeid)
        return digest[:8]
×
1315

1316
    def get_static_jit_scid_alias(self) -> bytes:
        """Return our own static JIT scid alias, derived from our node pubkey."""
        own_pubkey = self.node_keypair.pubkey
        return self._scid_alias_of_node(own_pubkey)
×
1318

1319
    @log_exceptions
    async def open_channel_just_in_time(
        self,
        *,
        next_peer: Peer,
        next_amount_msat_htlc: int,
        next_cltv_abs: int,
        payment_hash: bytes,
        next_onion: OnionPacket,
    ) -> str:
        """Open a zeroconf channel to *next_peer* and forward an incoming HTLC over it.

        Returns the serialized htlc key of the forwarded HTLC.
        If anything goes wrong during negotiation we raise an OnionRoutingFailure,
        which will cancel the incoming HTLC.
        """
        # prevent settling the htlc until the channel opening was successful so we can fail it if needed
        self.dont_settle_htlcs[payment_hash.hex()] = None
        try:
            funding_sat = 2 * (next_amount_msat_htlc // 1000)  # try to fully spend htlcs
            password = self.wallet.get_unlocked_password() if self.wallet.has_password() else None
            channel_opening_fee = next_amount_msat_htlc // 100
            if channel_opening_fee // 1000 < self.config.ZEROCONF_MIN_OPENING_FEE:
                self.logger.info('rejecting JIT channel: payment too low')
                raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'payment too low')
            self.logger.info(f'channel opening fee (sats): {channel_opening_fee//1000}')
            next_chan, funding_tx = await self.open_channel_with_peer(
                next_peer, funding_sat,
                push_sat=0,
                zeroconf=True,
                public=False,
                opening_fee=channel_opening_fee,
                password=password,
            )

            async def wait_for_channel():
                while not next_chan.is_open():
                    await asyncio.sleep(1)
            await util.wait_for2(wait_for_channel(), LN_P2P_NETWORK_TIMEOUT)
            next_chan.save_remote_scid_alias(self._scid_alias_of_node(next_peer.pubkey))
            self.logger.info('JIT channel is open')
            # the channel opening fee is deducted from the amount we forward
            next_amount_msat_htlc -= channel_opening_fee
            # fixme: some checks are missing
            htlc = next_peer.send_htlc(
                chan=next_chan,
                payment_hash=payment_hash,
                amount_msat=next_amount_msat_htlc,
                cltv_abs=next_cltv_abs,
                onion=next_onion)

            async def wait_for_preimage():
                while self.get_preimage(payment_hash) is None:
                    await asyncio.sleep(1)
            await util.wait_for2(wait_for_preimage(), LN_P2P_NETWORK_TIMEOUT)

            # We have been paid and can broadcast
            # todo: if broadcasting raise an exception, we should try to rebroadcast
            await self.network.broadcast_transaction(funding_tx)
        except OnionRoutingFailure:
            raise
        except Exception as e:
            # chain the original exception so the root cause is preserved in logs
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') from e
        finally:
            # pop with default: robust even if the key was already removed
            self.dont_settle_htlcs.pop(payment_hash.hex(), None)

        htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id)
        return htlc_key
×
1381

1382
    @log_exceptions
    async def open_channel_with_peer(
            self, peer, funding_sat, *,
            push_sat: int = 0,
            public: bool = False,
            zeroconf: bool = False,
            opening_fee: Optional[int] = None,
            password=None):
        """Open a channel with *peer*, funded with *funding_sat* from the wallet.

        Builds the funding tx from spendable coins and runs the establishment
        flow. Returns (chan, funding_tx).
        """
        if self.config.ENABLE_ANCHOR_CHANNELS:
            # anchor channels require the wallet keystore to be unlocked
            self.wallet.unlock(password)
        coins = self.wallet.get_spendable_coins(None)
        node_id = peer.pubkey
        fee_policy = FeePolicy(self.config.FEE_POLICY)
        funding_tx = self.mktx_for_open_channel(
            coins=coins,
            funding_sat=funding_sat,
            node_id=node_id,
            fee_policy=fee_policy)
        chan, funding_tx = await self._open_channel_coroutine(
            peer=peer,
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_sat=push_sat,
            public=public,
            zeroconf=zeroconf,
            opening_fee=opening_fee,
            password=password)
        return chan, funding_tx
×
1410

1411
    @log_exceptions
    async def _open_channel_coroutine(
            self, *,
            peer: Peer,
            funding_tx: PartialTransaction,
            funding_sat: int,
            push_sat: int,
            public: bool,
            zeroconf=False,
            opening_fee=None,
            password: Optional[str],
    ) -> Tuple[Channel, PartialTransaction]:
        """Run the channel-establishment flow with *peer*, then sign and
        (unless zeroconf) broadcast the funding transaction.

        Returns (chan, funding_tx). Raises if the requested capacity exceeds
        the configured maximum.
        """
        if funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
            raise Exception(
                _("Requested channel capacity is over maximum.")
                + f"\n{funding_sat} sat > {self.config.LIGHTNING_MAX_FUNDING_SAT} sat"
            )
        coro = peer.channel_establishment_flow(
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_msat=push_sat * 1000,
            public=public,
            zeroconf=zeroconf,
            opening_fee=opening_fee,
            temp_channel_id=os.urandom(32))
        chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
        util.trigger_callback('channels_updated', self.wallet)
        self.wallet.adb.add_transaction(funding_tx)  # save tx as local into the wallet
        self.wallet.sign_transaction(funding_tx, password)
        # zeroconf channels are broadcast later, once we got paid
        # (see open_channel_just_in_time)
        if funding_tx.is_complete() and not zeroconf:
            await self.network.try_broadcasting(funding_tx, 'open_channel')
        return chan, funding_tx
×
1444

1445
    def add_channel(self, chan: Channel):
        """Register *chan* in memory and start watching its funding outpoint."""
        with self.lock:
            self._channels[chan.channel_id] = chan
        self.lnwatcher.add_channel(chan)
×
1449

1450
    def add_new_channel(self, chan: Channel):
        """Register and persist a freshly created channel; roll back on a failed first save."""
        self.add_channel(chan)
        channels_db = self.db.get_dict('channels')
        channels_db[chan.channel_id.hex()] = chan.storage
        # reserve the channel's addresses so the wallet does not reuse them
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
        try:
            self.save_channel(chan)
        except Exception:
            # undo the registration so we don't keep a half-added channel around
            chan.set_state(ChannelState.REDEEMED)
            self.remove_channel(chan.channel_id)
            raise
×
1461

1462
    def cb_data(self, node_id: bytes) -> bytes:
        """Return the OP_RETURN payload identifying a recoverable-channel funding tx."""
        node_prefix = node_id[:NODE_ID_PREFIX_LEN]
        return CB_MAGIC_BYTES + node_prefix
×
1464

1465
    def decrypt_cb_data(self, encrypted_data: bytes, funding_address: str) -> bytes:
        """Decrypt channel-backup data whose nonce is bound to *funding_address*."""
        scripthash = bytes.fromhex(address_to_scripthash(funding_address))
        # nonce derived from the funding output's scripthash (first 12 bytes)
        return chacha20_decrypt(key=self.backup_key, data=encrypted_data, nonce=scripthash[:12])
×
1469

1470
    def encrypt_cb_data(self, data: bytes, funding_address: str) -> bytes:
        """Encrypt channel-backup data, binding the nonce to *funding_address*."""
        scripthash = bytes.fromhex(address_to_scripthash(funding_address))
        # note: we are only using chacha20 instead of chacha20+poly1305 to save onchain space
        #       (not have the 16 byte MAC). Otherwise, the latter would be preferable.
        return chacha20_encrypt(key=self.backup_key, data=data, nonce=scripthash[:12])
×
1476

1477
    def mktx_for_open_channel(
            self, *,
            coins: Sequence[PartialTxInput],
            funding_sat: int,
            node_id: bytes,
            fee_policy: FeePolicy,
    ) -> PartialTransaction:
        """Build an unsigned funding transaction of *funding_sat* towards *node_id*.

        Adds a recovery OP_RETURN output if recoverable channels are in use.
        The returned tx is non-RBF and has a deterministic locktime.
        """
        from .wallet import get_locktime_for_new_transaction

        outputs = [PartialTxOutput.from_address_and_value(DummyAddress.CHANNEL, funding_sat)]
        if self.has_recoverable_channels():
            # embed enough data on-chain to re-derive the channel backup later
            dummy_scriptpubkey = make_op_return(self.cb_data(node_id))
            outputs.append(PartialTxOutput(scriptpubkey=dummy_scriptpubkey, value=0))
        tx = self.wallet.make_unsigned_transaction(
            coins=coins,
            outputs=outputs,
            fee_policy=fee_policy,
            # we do not know yet if peer accepts anchors, just assume they do
            is_anchor_channel_opening=self.config.ENABLE_ANCHOR_CHANNELS,
        )
        tx.set_rbf(False)
        # rm randomness from locktime, as we use the locktime as entropy for deriving the funding_privkey
        # (and it would be confusing to get a collision as a consequence of the randomness)
        tx.locktime = get_locktime_for_new_transaction(self.network, include_random_component=False)
        return tx
×
1502

1503
    def suggest_funding_amount(self, amount_to_pay: int, coins: Sequence[PartialTxInput]) -> Tuple[int, int] | None:
        """Suggest a channel funding amount that would let us pay *amount_to_pay*.

        Returns (suggested_funding_sat, min_funding_sat), or None if we cannot
        afford (or are not allowed) to open such a channel.
        """
        num_sats_can_send = int(self.num_sats_can_send())
        lightning_needed = amount_to_pay - num_sats_can_send
        assert lightning_needed > 0
        min_funding_sat = lightning_needed + (lightning_needed // 20) + 1000  # safety margin
        min_funding_sat = max(min_funding_sat, MIN_FUNDING_SAT)  # at least MIN_FUNDING_SAT
        if min_funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
            return
        fee_policy = FeePolicy(f'feerate:{FEERATE_FALLBACK_STATIC_FEE}')

        def can_fund(amount_sat: int) -> bool:
            # probe whether the wallet can afford a funding tx of this size
            try:
                self.mktx_for_open_channel(
                    coins=coins, funding_sat=amount_sat, node_id=bytes(32), fee_policy=fee_policy)
            except NotEnoughFunds:
                return False
            return True

        if not can_fund(min_funding_sat):
            return
        funding_sat = min_funding_sat
        # if available, suggest twice that amount:
        if 2 * min_funding_sat <= self.config.LIGHTNING_MAX_FUNDING_SAT and can_fund(2 * min_funding_sat):
            funding_sat = 2 * min_funding_sat
        return funding_sat, min_funding_sat
×
1528

1529
    def open_channel(
            self, *,
            connect_str: str,
            funding_tx: PartialTransaction,
            funding_sat: int,
            push_amt_sat: int,
            public: bool = False,
            password: str = None,
    ) -> Tuple[Channel, PartialTransaction]:
        """Blocking entry point for opening a channel, callable from non-asyncio threads.

        Connects to the peer given by *connect_str*, then runs the channel
        establishment flow on the asyncio loop. Returns (chan, funding_tx).
        """
        fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
        try:
            peer = fut.result()
        except concurrent.futures.TimeoutError:
            # NOTE(review): fut.result() is called without a timeout, so this branch
            # appears unreachable — confirm whether a timeout argument was intended.
            raise Exception(_("add peer timed out"))
        coro = self._open_channel_coroutine(
            peer=peer,
            funding_tx=funding_tx,
            funding_sat=funding_sat,
            push_sat=push_amt_sat,
            public=public,
            password=password)
        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
        try:
            chan, funding_tx = fut.result()
        except concurrent.futures.TimeoutError:
            # NOTE(review): same as above — no timeout is passed to fut.result().
            raise Exception(_("open_channel timed out"))
        return chan, funding_tx
×
1557

1558
    def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
        """Return our channel matching *short_channel_id*, preferring real SCIDs.

        Real SCIDs are checked first: this e.g. protects against maliciously
        chosen SCID aliases and accidental collisions.
        """
        our_channels = list(self.channels.values())
        for candidate in our_channels:
            if candidate.short_channel_id == short_channel_id:
                return candidate
        # Now we also consider aliases.
        # TODO we should split this as this search currently ignores the "direction"
        #      of the aliases. We should only look at either the remote OR the local alias,
        #      depending on context.
        for candidate in our_channels:
            if candidate.get_remote_scid_alias() == short_channel_id:
                return candidate
            if candidate.get_local_scid_alias() == short_channel_id:
                return candidate
        return None
×
1574

1575
    def can_pay_invoice(self, invoice: Invoice) -> bool:
        """Return True if our spendable lightning balance covers *invoice*."""
        assert invoice.is_lightning()
        amount_sat = invoice.get_amount_sat() or 0
        return amount_sat <= self.num_sats_can_send()
×
1578

1579
    @log_exceptions
    async def pay_invoice(
            self, invoice: Invoice, *,
            amount_msat: int = None,
            attempts: int = None,  # used only in unit tests
            full_path: LNPaymentPath = None,
            channels: Optional[Sequence[Channel]] = None,
            budget: Optional[PaymentFeeBudget] = None,
    ) -> Tuple[bool, List[HtlcLog]]:
        """Pay a BOLT-11 *invoice*. Returns (success, list of HtlcLogs).

        Raises PaymentFailure if the invoice was already paid or a payment
        for it is already in flight.
        """
        bolt11 = invoice.lightning_invoice
        lnaddr = self._check_bolt11_invoice(bolt11, amount_msat=amount_msat)
        min_final_cltv_delta = lnaddr.get_min_final_cltv_delta()
        payment_hash = lnaddr.paymenthash
        key = payment_hash.hex()
        payment_secret = lnaddr.payment_secret
        invoice_pubkey = lnaddr.pubkey.serialize()
        invoice_features = lnaddr.get_features()
        r_tags = lnaddr.get_routing_info('r')
        amount_to_pay = lnaddr.get_amount_msat()
        # guard against double payment of the same invoice
        status = self.get_payment_status(payment_hash)
        if status == PR_PAID:
            raise PaymentFailure(_("This invoice has been paid already"))
        if status == PR_INFLIGHT:
            raise PaymentFailure(_("A payment was already initiated for this invoice"))
        if payment_hash in self.get_payments(status='inflight'):
            raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
        info = PaymentInfo(
            payment_hash=payment_hash,
            amount_msat=amount_to_pay,
            direction=SENT,
            status=PR_UNPAID,
            min_final_cltv_delta=min_final_cltv_delta,
            expiry_delay=LN_EXPIRY_NEVER,
        )
        self.save_payment_info(info)
        self.wallet.set_label(key, lnaddr.get_description())
        self.set_invoice_status(key, PR_INFLIGHT)
        if budget is None:
            budget = PaymentFeeBudget.from_invoice_amount(invoice_amount_msat=amount_to_pay, config=self.config)
        if attempts is None and self.uses_trampoline():
            # we don't expect lots of failed htlcs with trampoline, so we can fail sooner
            attempts = 30
        success = False
        try:
            await self.pay_to_node(
                node_pubkey=invoice_pubkey,
                payment_hash=payment_hash,
                payment_secret=payment_secret,
                amount_to_pay=amount_to_pay,
                min_final_cltv_delta=min_final_cltv_delta,
                r_tags=r_tags,
                invoice_features=invoice_features,
                attempts=attempts,
                full_path=full_path,
                channels=channels,
                budget=budget,
            )
            success = True
        except PaymentFailure as e:
            self.logger.info(f'payment failure: {e!r}')
            reason = str(e)
        except ChannelDBNotLoaded as e:
            self.logger.info(f'payment failure: {e!r}')
            reason = str(e)
        finally:
            self.logger.info(f"pay_invoice ending session for RHASH={payment_hash.hex()}. {success=}")
        # reflect the outcome in the invoice status and notify the GUI
        if success:
            self.set_invoice_status(key, PR_PAID)
            util.trigger_callback('payment_succeeded', self.wallet, key)
        else:
            self.set_invoice_status(key, PR_UNPAID)
            util.trigger_callback('payment_failed', self.wallet, key, reason)
        log = self.logs[key]
        return success, log
1✔
1653

1654
    async def pay_to_node(
            self, *,
            node_pubkey: bytes,
            payment_hash: bytes,
            payment_secret: bytes,
            amount_to_pay: int,  # in msat
            min_final_cltv_delta: int,
            r_tags,
            invoice_features: int,
            attempts: int = None,
            full_path: LNPaymentPath = None,
            fwd_trampoline_onion: OnionPacket = None,
            budget: PaymentFeeBudget,
            channels: Optional[Sequence[Channel]] = None,
            fw_payment_key: str = None,  # for forwarding
    ) -> None:
        """Drive a payment session to *node_pubkey*: repeatedly create routes,
        send HTLCs and process their resolutions until success, giving up
        after *attempts* tries (or the configured timeout when attempts is None).

        Can raise PaymentFailure, ChannelDBNotLoaded,
        or OnionRoutingFailure (if forwarding trampoline).
        """

        assert budget
        assert budget.fee_msat >= 0, budget
        assert budget.cltv >= 0, budget

        payment_key = payment_hash + payment_secret
        assert payment_key not in self._paysessions
        self._paysessions[payment_key] = paysession = PaySession(
            payment_hash=payment_hash,
            payment_secret=payment_secret,
            initial_trampoline_fee_level=self.config.INITIAL_TRAMPOLINE_FEE_LEVEL,
            invoice_features=invoice_features,
            r_tags=r_tags,
            min_final_cltv_delta=min_final_cltv_delta,
            amount_to_pay=amount_to_pay,
            invoice_pubkey=node_pubkey,
            uses_trampoline=self.uses_trampoline(),
            # the config option to use two trampoline hops for legacy payments has been removed as
            # the trampoline onion is too small (400 bytes) to accommodate two trampoline hops and
            # routing hints, making the functionality unusable for payments that require routing hints.
            # TODO: if you read this, the year is 2027 and there is no use for the second trampoline
            # hop code anymore remove the code completely.
            use_two_trampolines=False,
        )
        self.logs[payment_hash.hex()] = log = []  # TODO incl payment_secret in key (re trampoline forwarding)

        paysession.logger.info(
            f"pay_to_node starting session for RHASH={payment_hash.hex()}. "
            f"using_trampoline={self.uses_trampoline()}. "
            f"invoice_features={paysession.invoice_features.get_names()}. "
            f"{amount_to_pay=} msat. {budget=}")
        if not self.uses_trampoline():
            self.logger.info(
                f"gossip_db status. sync progress: {self.network.lngossip.get_sync_progress_estimate()}. "
                f"num_nodes={self.channel_db.num_nodes}, "
                f"num_channels={self.channel_db.num_channels}, "
                f"num_policies={self.channel_db.num_policies}.")

        # when encountering trampoline forwarding difficulties in the legacy case, we
        # sometimes need to fall back to a single trampoline forwarder, at the expense
        # of privacy
        try:
            while True:
                if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
                    # 1. create a set of routes for remaining amount.
                    # note: path-finding runs in a separate thread so that we don't block the asyncio loop
                    # graph updates might occur during the computation
                    # scale the remaining fee budget proportionally to the remaining amount
                    remaining_fee_budget_msat = (budget.fee_msat * amount_to_send) // amount_to_pay
                    routes = self.create_routes_for_payment(
                        paysession=paysession,
                        amount_msat=amount_to_send,
                        full_path=full_path,
                        fwd_trampoline_onion=fwd_trampoline_onion,
                        channels=channels,
                        budget=budget._replace(fee_msat=remaining_fee_budget_msat),
                    )
                    # 2. send htlcs
                    async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
                        await self.pay_to_route(
                            paysession=paysession,
                            sent_htlc_info=sent_htlc_info,
                            min_final_cltv_delta=cltv_delta,
                            trampoline_onion=trampoline_onion,
                            fw_payment_key=fw_payment_key,
                        )
                    # invoice_status is triggered in self.set_invoice_status when it actually changes.
                    # It is also triggered here to update progress for a lightning payment in the GUI
                    # (e.g. attempt counter)
                    util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
                # 3. await a queue, collect resolved htlcs
                htlc_log = await paysession.wait_for_one_htlc_to_resolve()
                while True:
                    log.append(htlc_log)
                    # raises PaymentSuccess once a successful HTLC resolution arrives
                    await self._process_htlc_log(
                        paysession=paysession, htlc_log=htlc_log, is_forwarding_trampoline=bool(fwd_trampoline_onion))
                    if paysession.number_htlcs_inflight < 1:
                        break
                    # wait a bit, more failures might come
                    try:
                        htlc_log = await util.wait_for2(
                            paysession.wait_for_one_htlc_to_resolve(),
                            timeout=paysession.TIMEOUT_WAIT_FOR_NEXT_RESOLVED_HTLC)
                    except asyncio.TimeoutError:
                        break

                # max attempts or timeout
                if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
                    raise PaymentFailure('Giving up after %d attempts'%len(log))
        except PaymentSuccess:
            pass
        finally:
            paysession.is_active = False
            if paysession.can_be_deleted():
                self._paysessions.pop(payment_key)
            paysession.logger.info(f"pay_to_node ending session for RHASH={payment_hash.hex()}")
1✔
1769

1770
    async def _process_htlc_log(
        self,
        *,
        paysession: PaySession,
        htlc_log: HtlcLog,
        is_forwarding_trampoline: bool,
    ) -> None:
        """Handle a single just-resolved HTLC, as part of a payment-session.

        Can raise PaymentFailure, PaymentSuccess,
        or OnionRoutingFailure (if forwarding trampoline).
        """
        if htlc_log.success:
            if self.network.path_finder:
                # TODO: report every route to liquidity hints for mpp
                # in the case of success, we report channels of the
                # route as being able to send the same amount in the future,
                # as we assume to not know the capacity
                self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat)
                # remove inflight htlcs from liquidity hints
                self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
            raise PaymentSuccess()
        # htlc failed
        # if we get a tmp channel failure, it might work to split the amount and try more routes
        # if we get a channel update, we might retry the same route and amount
        route = htlc_log.route
        sender_idx = htlc_log.sender_idx
        failure_msg = htlc_log.failure_msg
        if sender_idx is None:
            # we could not attribute the error to a hop; fail the whole payment
            raise PaymentFailure(failure_msg.code_name())
        erring_node_id = route[sender_idx].node_id
        code, data = failure_msg.code, failure_msg.data
        self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. "
                         f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}")
        self.logger.info(f"error reported by {erring_node_id.hex()}")
        if code == OnionFailureCode.MPP_TIMEOUT:
            raise PaymentFailure(failure_msg.code_name())
        # errors returned by the next trampoline.
        if is_forwarding_trampoline and code in [
                OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
                OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON]:
            # propagate the failure back to the sender of the trampoline onion
            raise failure_msg
        # trampoline
        if self.uses_trampoline():
            paysession.handle_failed_trampoline_htlc(
                htlc_log=htlc_log, failure_msg=failure_msg)
        else:
            self.handle_error_code_from_failed_htlc(
                route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
1819

1820
    async def pay_to_route(
            self, *,
            paysession: PaySession,
            sent_htlc_info: SentHtlcInfo,
            min_final_cltv_delta: int,
            trampoline_onion: Optional[OnionPacket] = None,
            fw_payment_key: str = None,
    ) -> None:
        """Sends a single HTLC along the first hop of sent_htlc_info.route.

        Raises PaymentFailure if the first-hop peer is gone. Records the HTLC
        in the paysession and in self.sent_htlcs_info for later resolution.
        """
        shi = sent_htlc_info
        del sent_htlc_info  # just renamed
        short_channel_id = shi.route[0].short_channel_id
        chan = self.get_channel_by_short_id(short_channel_id)
        assert chan, ShortChannelID(short_channel_id)
        peer = self._peers.get(shi.route[0].node_id)
        if not peer:
            raise PaymentFailure('Dropped peer')
        await peer.initialized
        htlc = peer.pay(
            route=shi.route,
            chan=chan,
            amount_msat=shi.amount_msat,
            total_msat=shi.bucket_msat,
            payment_hash=paysession.payment_hash,
            min_final_cltv_delta=min_final_cltv_delta,
            payment_secret=shi.payment_secret_bucket,
            trampoline_onion=trampoline_onion)

        # remember the htlc so its resolution can be matched back to this session
        key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
        self.sent_htlcs_info[key] = shi
        paysession.add_new_htlc(shi)
        if fw_payment_key:
            htlc_key = serialize_htlc_key(short_channel_id, htlc.htlc_id)
            self.logger.info(f'adding active forwarding {fw_payment_key}')
            self.active_forwardings[fw_payment_key].append(htlc_key)
        if self.network.path_finder:
            # add inflight htlcs to liquidity hints
            self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
        util.trigger_callback('htlc_added', chan, htlc, SENT)
1✔
1859

1860
    def handle_error_code_from_failed_htlc(
            self,
            *,
            route: LNPaymentRoute,
            sender_idx: int,
            failure_msg: OnionRoutingFailure,
            amount: int) -> None:
        """Digest an onion failure from a failed HTLC (non-trampoline mode):
        apply any embedded channel_update, update liquidity hints, and/or
        blacklist the failing edge. Raises PaymentFailure when nothing can be done.
        """
        assert self.channel_db  # cannot be in trampoline mode
        assert self.network.path_finder

        # remove inflight htlcs from liquidity hints
        self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False)

        code, data = failure_msg.code, failure_msg.data
        # TODO can we use lnmsg.OnionWireSerializer here?
        # TODO update onion_wire.csv
        # handle some specific error codes
        # map: failure code -> byte offset of the embedded channel_update in `data`
        failure_codes = {
            OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0,
            OnionFailureCode.AMOUNT_BELOW_MINIMUM: 8,
            OnionFailureCode.FEE_INSUFFICIENT: 8,
            OnionFailureCode.INCORRECT_CLTV_EXPIRY: 4,
            OnionFailureCode.EXPIRY_TOO_SOON: 0,
            OnionFailureCode.CHANNEL_DISABLED: 2,
        }
        try:
            failing_channel = route[sender_idx + 1].short_channel_id
        except IndexError:
            # the error came from the final node: there is no next channel to blame
            raise PaymentFailure(f'payment destination reported error: {failure_msg.code_name()}') from None

        # TODO: handle unknown next peer?
        # handle failure codes that include a channel update
        if code in failure_codes:
            offset = failure_codes[code]
            channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big")
            channel_update_as_received = data[offset+2: offset+2+channel_update_len]
            payload = self._decode_channel_update_msg(channel_update_as_received)
            if payload is None:
                self.logger.info(f'could not decode channel_update for failed htlc: '
                                 f'{channel_update_as_received.hex()}')
                blacklist = True
            elif payload.get('short_channel_id') != failing_channel:
                self.logger.info(f'short_channel_id in channel_update does not match our route')
                blacklist = True
            else:
                # apply the channel update or get blacklisted
                blacklist, update = self._handle_chanupd_from_failed_htlc(
                    payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg)
                # we interpret a temporary channel failure as a liquidity issue
                # in the channel and update our liquidity hints accordingly
                if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE:
                    self.network.path_finder.update_liquidity_hints(
                        route,
                        amount,
                        failing_channel=ShortChannelID(failing_channel))
                # if we can't decide on some action, we are stuck
                if not (blacklist or update):
                    raise PaymentFailure(failure_msg.code_name())
        # for errors that do not include a channel update
        else:
            blacklist = True
        if blacklist:
            self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel)
1✔
1924

1925
    def _handle_chanupd_from_failed_htlc(
        self, payload, *,
        route: LNPaymentRoute,
        sender_idx: int,
        failure_msg: OnionRoutingFailure,
    ) -> Tuple[bool, bool]:
        """Apply a channel_update extracted from an htlc failure message.

        Returns a (blacklist, update) pair: whether the failing channel should
        be blacklisted, and whether the update was applied to our graph view.
        """
        try:
            status = self.channel_db.add_channel_update(payload, verify=True)
        except InvalidGossipMsg:
            return True, False  # blacklist
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        if status == UpdateStatus.GOOD:
            self.logger.info(f"applied channel update to {short_channel_id}")
            # TODO: add test for this
            # FIXME: this does not work for our own unannounced channels.
            for chan in self.channels.values():
                if chan.short_channel_id == short_channel_id:
                    chan.set_remote_update(payload)
            return False, True
        if status == UpdateStatus.ORPHANED:
            # maybe it is a private channel (and data in invoice was outdated)
            self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
            start_node_id = route[sender_idx].node_id
            cache_ttl = None
            if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED:
                # eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay
                # a mobile phone with the app closed. So we cache this with a short TTL.
                cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT
            applied = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl)
            # blacklist exactly when the private update could not be applied
            return (not applied), applied
        if status == UpdateStatus.EXPIRED:
            return True, False
        if status == UpdateStatus.DEPRECATED:
            self.logger.info(f'channel update is not more recent.')
            return True, False
        if status == UpdateStatus.UNCHANGED:
            return True, False
        # unknown status: neither blacklist nor treat as applied
        return False, False
1✔
1965

1966
    @classmethod
    def _decode_channel_update_msg(cls, chan_upd_msg: bytes) -> Optional[Dict[str, Any]]:
        """Decode a channel_update that was embedded in an onion failure message.

        Some nodes include the leading two-byte msg_type (258) in the embedded
        update, some do not, so both variants are attempted: first with the
        type prefix prepended, then the bytes as received.

        Returns the decoded payload (with the successful raw bytes stored under
        'raw'), or None if neither variant decodes to a channel_update for our
        chain.
        """
        # 258 is the wire msg_type of channel_update
        channel_update_typed = (258).to_bytes(length=2, byteorder="big") + chan_upd_msg
        for candidate in (channel_update_typed, chan_upd_msg):
            try:
                message_type, payload = decode_msg(candidate)
                # reject updates for a different chain
                if payload['chain_hash'] != constants.net.rev_genesis_bytes():
                    continue
                payload['raw'] = candidate
                return payload
            except Exception:  # FIXME: too broad
                continue
        return None
1✔
1985

1986
    def _check_bolt11_invoice(self, bolt11_invoice: str, *, amount_msat: int = None) -> LnAddr:
        """Parses and validates a bolt11 invoice str into a LnAddr.
        Includes pre-payment checks external to the parser.

        amount_msat, if given, overrides the invoice amount (must not be lower
        than what the invoice states; main use case is zero-amount invoices).

        Raises InvoiceError if the invoice is expired, lacks an amount, wants
        an unreasonable CLTV delta, or has incompatible features.
        """
        addr = lndecode(bolt11_invoice)
        if addr.is_expired():
            raise InvoiceError(_("This invoice has expired"))
        # check amount
        if amount_msat:  # replace amt in invoice. main usecase is paying zero amt invoices
            existing_amt_msat = addr.get_amount_msat()
            if existing_amt_msat and amount_msat < existing_amt_msat:
                # use InvoiceError (not generic Exception) for consistency with the other checks here
                raise InvoiceError("cannot pay lower amt than what is originally in LN invoice")
            addr.amount = Decimal(amount_msat) / COIN / 1000
        if addr.amount is None:
            raise InvoiceError(_("Missing amount"))
        # check cltv
        if addr.get_min_final_cltv_delta() > NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
            raise InvoiceError("{}\n{}".format(
                _("Invoice wants us to risk locking funds for unreasonably long."),
                f"min_final_cltv_delta: {addr.get_min_final_cltv_delta()}"))
        # check features
        addr.validate_and_compare_features(self.features)
        return addr
1✔
2009

2010
    def is_trampoline_peer(self, node_id: bytes) -> bool:
        """Whether the given node can act as a trampoline forwarder for us."""
        # until trampoline is advertised in lnfeatures, check against hardcoded list
        if is_hardcoded_trampoline(node_id):
            return True
        peer = self._peers.get(node_id)
        if not peer:
            return False
        # otherwise rely on the features the peer announced to us
        return any(
            peer.their_features.supports(flag)
            for flag in (LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR,
                         LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM)
        )
2019

2020
    def suggest_peer(self) -> Optional[bytes]:
        """Suggest a node pubkey to open a channel with."""
        if self.uses_trampoline():
            # trampoline wallets only peer with the known trampoline nodes
            return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
        return self.lnrater.suggest_peer()
×
2025

2026
    def suggest_payment_splits(
        self,
        *,
        amount_msat: int,
        final_total_msat: int,
        my_active_channels: Sequence[Channel],
        invoice_features: LnFeatures,
        r_tags: Sequence[Sequence[Sequence[Any]]],
        receiver_pubkey: bytes,
    ) -> List['SplitConfigRating']:
        """Return rated configurations for splitting amount_msat over our channels.

        amount_msat: the part of the payment we want split suggestions for now.
        final_total_msat: the total amount the receiver is to get.
        Trampoline and MPP feature support constrain which splits are allowed.
        """
        channels_with_funds = {
            (chan.channel_id, chan.node_id): (int(chan.available_to_spend(HTLCOwner.LOCAL)), chan.htlc_slots_left(HTLCOwner.LOCAL))
            for chan in my_active_channels
        }
        # if we have a direct channel it's preferable to send a single part directly through this
        # channel, so this bool will disable excluding single part payments
        have_direct_channel = any(chan.node_id == receiver_pubkey for chan in my_active_channels)
        self.logger.info(f"channels_with_funds: {channels_with_funds}, {have_direct_channel=}")
        exclude_single_part_payments = False
        if self.uses_trampoline():
            # in the case of a legacy payment, we don't allow splitting via different
            # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
            # (the discarded tuple element is deliberately not named "_", which would
            #  shadow the gettext alias imported at the top of the file)
            is_legacy, _discarded = is_legacy_relay(invoice_features, r_tags)
            exclude_multinode_payments = is_legacy
            # we don't split within a channel when sending to a trampoline node,
            # the trampoline node will split for us
            exclude_single_channel_splits = not self.config.TEST_FORCE_MPP
        else:
            exclude_multinode_payments = False
            exclude_single_channel_splits = False
            if invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and not self.config.TEST_FORCE_DISABLE_MPP:
                # if amt is still large compared to total_msat, split it:
                if (amount_msat / final_total_msat > self.MPP_SPLIT_PART_FRACTION
                        and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT
                        and not have_direct_channel):
                    exclude_single_part_payments = True

        split_configurations = suggest_splits(
            amount_msat,
            channels_with_funds,
            exclude_single_part_payments=exclude_single_part_payments,
            exclude_multinode_payments=exclude_multinode_payments,
            exclude_single_channel_splits=exclude_single_channel_splits
        )

        self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
        return split_configurations
1✔
2073

2074
    async def create_routes_for_payment(
            self, *,
            paysession: PaySession,
            amount_msat: int,        # part of payment amount we want routes for now
            fwd_trampoline_onion: OnionPacket = None,
            full_path: LNPaymentPath = None,
            channels: Optional[Sequence[Channel]] = None,
            budget: PaymentFeeBudget,
    ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:

        """Creates multiple routes for splitting a payment over the available
        private channels.

        We first try to conduct the payment over a single channel. If that fails
        and mpp is supported by the receiver, we will split the payment.

        Yields one (SentHtlcInfo, final_cltv_delta, trampoline_onion) triple per
        htlc of the first split configuration for which all parts could be
        routed; raises NoPathFound (or the last FeeBudgetExceeded seen) when no
        configuration works.
        """
        trampoline_features = LnFeatures.VAR_ONION_OPT
        local_height = self.wallet.adb.get_local_height()
        fee_related_error = None  # type: Optional[FeeBudgetExceeded]
        if channels:
            my_active_channels = channels
        else:
            # default to all channels that are usable for sending
            my_active_channels = [
                chan for chan in self.channels.values() if
                chan.is_active() and not chan.is_frozen_for_sending()]
        # try random order
        random.shuffle(my_active_channels)
        split_configurations = self.suggest_payment_splits(
            amount_msat=amount_msat,
            final_total_msat=paysession.amount_to_pay,
            my_active_channels=my_active_channels,
            invoice_features=paysession.invoice_features,
            r_tags=paysession.r_tags,
            receiver_pubkey=paysession.invoice_pubkey,
        )
        for sc in split_configurations:
            is_multichan_mpp = len(sc.config.items()) > 1
            is_mpp = sc.config.number_parts() > 1
            # skip configurations incompatible with receiver features / test flags
            if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
                continue
            if not is_mpp and self.config.TEST_FORCE_MPP:
                continue
            if is_mpp and self.config.TEST_FORCE_DISABLE_MPP:
                continue
            self.logger.info(f"trying split configuration: {sc.config.values()} rating: {sc.rating}")
            routes = []
            try:
                if self.uses_trampoline():
                    per_trampoline_channel_amounts = defaultdict(list)
                    # categorize by trampoline nodes for trampoline mpp construction
                    for (chan_id, _), part_amounts_msat in sc.config.items():
                        chan = self.channels[chan_id]
                        for part_amount_msat in part_amounts_msat:
                            per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
                    # for each trampoline forwarder, construct mpp trampoline
                    for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
                        per_trampoline_amount = sum([x[1] for x in trampoline_parts])
                        trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
                            amount_msat=per_trampoline_amount,
                            total_msat=paysession.amount_to_pay,
                            min_final_cltv_delta=paysession.min_final_cltv_delta,
                            my_pubkey=self.node_keypair.pubkey,
                            invoice_pubkey=paysession.invoice_pubkey,
                            invoice_features=paysession.invoice_features,
                            node_id=trampoline_node_id,
                            r_tags=paysession.r_tags,
                            payment_hash=paysession.payment_hash,
                            payment_secret=paysession.payment_secret,
                            local_height=local_height,
                            trampoline_fee_level=paysession.trampoline_fee_level,
                            use_two_trampolines=paysession.use_two_trampolines,
                            failed_routes=paysession.failed_trampoline_routes,
                            # fee budget is divided evenly among the trampoline forwarders
                            budget=budget._replace(fee_msat=budget.fee_msat // len(per_trampoline_channel_amounts)),
                        )
                        # node_features is only used to determine is_tlv
                        per_trampoline_secret = os.urandom(32)
                        per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
                        self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
                        self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
                        self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
                        for chan_id, part_amount_msat in trampoline_parts:
                            chan = self.channels[chan_id]
                            # how much extra this channel can carry on top of its part
                            margin = chan.available_to_spend(LOCAL) - part_amount_msat
                            delta_fee = min(per_trampoline_fees, margin)
                            # TODO: distribute trampoline fee over several channels?
                            part_amount_msat_with_fees = part_amount_msat + delta_fee
                            per_trampoline_fees -= delta_fee
                            # single-hop route to the trampoline forwarder; it routes onward
                            route = [
                                RouteEdge(
                                    start_node=self.node_keypair.pubkey,
                                    end_node=trampoline_node_id,
                                    short_channel_id=chan.short_channel_id,
                                    fee_base_msat=0,
                                    fee_proportional_millionths=0,
                                    cltv_delta=0,
                                    node_features=trampoline_features)
                            ]
                            self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
                            shi = SentHtlcInfo(
                                route=route,
                                payment_secret_orig=paysession.payment_secret,
                                payment_secret_bucket=per_trampoline_secret,
                                amount_msat=part_amount_msat_with_fees,
                                bucket_msat=per_trampoline_amount_with_fees,
                                amount_receiver_msat=part_amount_msat,
                                trampoline_fee_level=paysession.trampoline_fee_level,
                                trampoline_route=trampoline_route,
                            )
                            routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
                        if per_trampoline_fees != 0:
                            # no channel had enough margin left to absorb the trampoline fee
                            e = 'not enough margin to pay trampoline fee'
                            self.logger.info(e)
                            raise FeeBudgetExceeded(e)
                else:
                    # We atomically loop through a split configuration. If there was
                    # a failure to find a path for a single part, we try the next configuration
                    for (chan_id, _), part_amounts_msat in sc.config.items():
                        for part_amount_msat in part_amounts_msat:
                            channel = self.channels[chan_id]
                            # path finding is CPU-bound; run it off the event loop
                            route = await run_in_thread(
                                partial(
                                    self.create_route_for_single_htlc,
                                    amount_msat=part_amount_msat,
                                    invoice_pubkey=paysession.invoice_pubkey,
                                    min_final_cltv_delta=paysession.min_final_cltv_delta,
                                    r_tags=paysession.r_tags,
                                    invoice_features=paysession.invoice_features,
                                    my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
                                    full_path=full_path,
                                    budget=budget._replace(fee_msat=budget.fee_msat // sc.config.number_parts()),
                                )
                            )
                            shi = SentHtlcInfo(
                                route=route,
                                payment_secret_orig=paysession.payment_secret,
                                payment_secret_bucket=paysession.payment_secret,
                                amount_msat=part_amount_msat,
                                bucket_msat=paysession.amount_to_pay,
                                amount_receiver_msat=part_amount_msat,
                                trampoline_fee_level=None,
                                trampoline_route=None,
                            )
                            routes.append((shi, paysession.min_final_cltv_delta, fwd_trampoline_onion))
            except NoPathFound:
                continue
            except FeeBudgetExceeded as e:
                # remember the budget failure; only raised if no later config succeeds
                fee_related_error = e
                continue
            for route in routes:
                yield route
            return
        if fee_related_error is not None:
            raise fee_related_error
        raise NoPathFound()
1✔
2227

2228
    @profiler
    def create_route_for_single_htlc(
            self, *,
            amount_msat: int,  # that final receiver gets
            invoice_pubkey: bytes,
            min_final_cltv_delta: int,
            r_tags,
            invoice_features: int,
            my_sending_channels: List[Channel],
            full_path: Optional[LNPaymentPath],
            budget: PaymentFeeBudget,
    ) -> LNPaymentRoute:
        """Find a route for one htlc, taking the invoice's private route hints
        (r_tags) into account.

        Raises NoPathFound if no route exists, FeeBudgetExceeded if every found
        route is too expensive, and LNPathInconsistent if the found route does
        not end at the invoice pubkey.
        """
        my_sending_aliases = set(chan.get_local_scid_alias() for chan in my_sending_channels)
        # re-bind to a scid -> Channel map, as expected by the path finder
        my_sending_channels = {chan.short_channel_id: chan for chan in my_sending_channels
            if chan.short_channel_id is not None}
        # Collect all private edges from route hints.
        # Note: if some route hints are multiple edges long, and these paths cross each other,
        #       we allow our path finding to cross the paths; i.e. the route hints are not isolated.
        private_route_edges = {}  # type: Dict[ShortChannelID, RouteEdge]
        for private_path in r_tags:
            # we need to shift the node pubkey by one towards the destination:
            private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
            private_path_rest = [edge[1:] for edge in private_path]
            start_node = private_path[0][0]
            # remove aliases from direct routes
            if len(private_path) == 1 and private_path[0][1] in my_sending_aliases:
                self.logger.info(f'create_route: skipping alias {ShortChannelID(private_path[0][1])}')
                continue
            for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
                short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_delta = edge_rest
                short_channel_id = ShortChannelID(short_channel_id)
                if (our_chan := self.get_channel_by_short_id(short_channel_id)) is not None:
                    # check if the channel is one of our channels and frozen for sending
                    if our_chan.is_frozen_for_sending():
                        continue
                # if we have a routing policy for this edge in the db, that takes precedence,
                # as it is likely from a previous failure
                channel_policy = self.channel_db.get_policy_for_node(
                    short_channel_id=short_channel_id,
                    node_id=start_node,
                    my_channels=my_sending_channels)
                if channel_policy:
                    fee_base_msat = channel_policy.fee_base_msat
                    fee_proportional_millionths = channel_policy.fee_proportional_millionths
                    cltv_delta = channel_policy.cltv_delta
                node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
                route_edge = RouteEdge(
                        start_node=start_node,
                        end_node=end_node,
                        short_channel_id=short_channel_id,
                        fee_base_msat=fee_base_msat,
                        fee_proportional_millionths=fee_proportional_millionths,
                        cltv_delta=cltv_delta,
                        node_features=node_info.features if node_info else 0)
                private_route_edges[route_edge.short_channel_id] = route_edge
                # walk one hop towards the destination
                start_node = end_node
        # now find a route, end to end: between us and the recipient
        try:
            route = self.network.path_finder.find_route(
                nodeA=self.node_keypair.pubkey,
                nodeB=invoice_pubkey,
                invoice_amount_msat=amount_msat,
                path=full_path,
                my_sending_channels=my_sending_channels,
                private_route_edges=private_route_edges)
        except NoChannelPolicy as e:
            raise NoPathFound() from e
        if not route:
            raise NoPathFound()
        # enforce the caller-supplied fee/cltv budget
        if not is_route_within_budget(
            route, budget=budget, amount_msat_for_dest=amount_msat, cltv_delta_for_dest=min_final_cltv_delta,
        ):
            self.logger.info(f"rejecting route (exceeds budget): {route=}. {budget=}")
            raise FeeBudgetExceeded()
        assert len(route) > 0
        if route[-1].end_node != invoice_pubkey:
            raise LNPathInconsistent("last node_id != invoice pubkey")
        # add features from invoice
        route[-1].node_features |= invoice_features
        return route
1✔
2309

2310
    def clear_invoices_cache(self):
        """Drop every cached (LnAddr, bolt11 string) pair."""
        cache = self._bolt11_cache
        cache.clear()
×
2312

2313
    def get_bolt11_invoice(
            self, *,
            payment_info: PaymentInfo,
            message: str,
            fallback_address: Optional[str],
            channels: Optional[Sequence[Channel]] = None,
    ) -> Tuple[LnAddr, str]:
        """Return (LnAddr, bolt11 string) for payment_info, creating and
        caching the pair on first call (keyed by payment_hash).

        channels, if given, restricts which channels are used for routing hints.
        """
        amount_msat = payment_info.amount_msat
        pair = self._bolt11_cache.get(payment_info.payment_hash)
        if pair:
            lnaddr, invoice = pair
            # a cached invoice must still agree with the requested amount
            assert lnaddr.get_amount_msat() == amount_msat
            return pair

        assert amount_msat is None or amount_msat > 0
        timestamp = int(time.time())
        # whether receiving needs the LSP to open a just-in-time channel
        needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
        routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit)
        self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}")
        invoice_features = self.features.for_invoice()
        if not self.uses_trampoline():
            invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
        if needs_jit:
            # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
            invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
        payment_secret = self.get_payment_secret(payment_info.payment_hash)
        amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
        # advertise some buffer on top of what we will actually accept
        min_final_cltv_delta = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE
        lnaddr = LnAddr(
            paymenthash=payment_info.payment_hash,
            amount=amount_btc,
            tags=[
                ('d', message),                      # description
                ('c', min_final_cltv_delta),         # min final cltv delta
                ('x', payment_info.expiry_delay),    # invoice expiry
                ('9', invoice_features),             # feature bits
                ('f', fallback_address),             # on-chain fallback
            ] + routing_hints,
            date=timestamp,
            payment_secret=payment_secret)
        invoice = lnencode(lnaddr, self.node_keypair.privkey)
        pair = lnaddr, invoice
        self._bolt11_cache[payment_info.payment_hash] = pair
        return pair
1✔
2357

2358
    def get_payment_secret(self, payment_hash):
        """Deterministically derive the payment_secret for a payment_hash
        from our payment_secret_key."""
        key_digest = sha256(self.payment_secret_key)
        return sha256(key_digest + payment_hash)
1✔
2360

2361
    def _get_payment_key(self, payment_hash: bytes) -> bytes:
1✔
2362
        """Return payment bucket key.
2363
        We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
2364
        as it changes over a trampoline path (in the outer onion), and these paths can overlap.
2365
        """
2366
        payment_secret = self.get_payment_secret(payment_hash)
1✔
2367
        return payment_hash + payment_secret
1✔
2368

2369
    def create_payment_info(
        self, *,
        amount_msat: Optional[int],
        min_final_cltv_delta: Optional[int] = None,
        exp_delay: int = LN_EXPIRY_NEVER,
        write_to_disk=True
    ) -> bytes:
        """Generate a fresh preimage and register an incoming PaymentInfo for it.

        Returns the payment_hash of the newly created payment.
        """
        preimage = os.urandom(32)
        payment_hash = sha256(preimage)
        # falsy (None/0) cltv delta falls back to the accepted minimum
        cltv_delta = min_final_cltv_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED
        info = PaymentInfo(
            payment_hash=payment_hash,
            amount_msat=amount_msat,
            direction=RECEIVED,
            status=PR_UNPAID,
            min_final_cltv_delta=cltv_delta,
            expiry_delay=exp_delay,
        )
        self.save_preimage(payment_hash, preimage, write_to_disk=False)
        self.save_payment_info(info, write_to_disk=False)
        if write_to_disk:
            self.wallet.save_db()
        return payment_hash
1✔
2392

2393
    def bundle_payments(self, hash_list: Sequence[bytes]) -> None:
        """Bundle together a list of payment_hashes, for atomicity, so that either
        - all gets fulfilled, or
        - none of them gets fulfilled.
        (we are the recipient of this payment)
        """
        payment_keys = [self._get_payment_key(h) for h in hash_list]
        with self.lock:
            # We maintain two maps.
            #   map1: payment_key -> bundle_key=canon_pkey (canonically smallest among pkeys)
            #   map2: bundle_key -> list of pkeys in bundle
            # assumption: bundles are immutable, so no adding extra pkeys after-the-fact
            canon = min(payment_keys)
            for key in payment_keys:
                assert key not in self._payment_bundles_pkey_to_canon
            self._payment_bundles_pkey_to_canon.update((key, canon) for key in payment_keys)
            self._payment_bundles_canon_to_pkeylist[canon] = tuple(payment_keys)
1✔
2411

2412
    def get_payment_bundle(self, payment_key: Union[bytes, str]) -> Sequence[bytes]:
        """Return the tuple of payment_keys bundled with *payment_key*,
        or an empty list when it belongs to no bundle."""
        with self.lock:
            key = payment_key
            if isinstance(key, str):
                try:
                    key = bytes.fromhex(key)
                except ValueError:
                    # might be a forwarding payment_key which is not hex and will never have a bundle
                    return []
            canon = self._payment_bundles_pkey_to_canon.get(key)
            return [] if canon is None else self._payment_bundles_canon_to_pkeylist[canon]
1✔
2424

2425
    def is_payment_bundle_complete(self, any_payment_key: str) -> bool:
        """
        complete means a htlc set is available for each payment key of the payment bundle and
        all htlc sets have a resolution >= COMPLETE (we got the whole payment bundle amount)
        """
        bundle_keys = self.get_payment_bundle(any_payment_key)
        # no bundle -> nothing to wait for
        if not bundle_keys:
            return True
        for pkey in bundle_keys:
            htlc_set = self.received_mpp_htlcs.get(pkey.hex())
            if htlc_set is None:
                # missing htlc set: it might have already been failed and deleted
                return False
            if htlc_set.resolution not in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING):
                return False
        return True
1✔
2443

2444
    def delete_payment_bundle(
        self, *,
        payment_hash: Optional[bytes] = None,
        payment_key: Optional[bytes] = None,
    ) -> None:
        """Forget the bundle that contains the given key (or hash); no-op when absent."""
        assert (payment_hash is not None) ^ (payment_key is not None), \
                    "must provide exactly one of (payment_hash, payment_key)"
        if not payment_key:
            payment_key = self._get_payment_key(payment_hash)
        with self.lock:
            canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key)
            if canon_pkey is None:  # is it ok for bundle to be missing??
                return
            # drop both directions of the mapping
            for pkey in self._payment_bundles_canon_to_pkeylist.pop(canon_pkey):
                del self._payment_bundles_pkey_to_canon[pkey]
    def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
        """Persist *preimage* for *payment_hash*, verifying that they match first."""
        if sha256(preimage) != payment_hash:
            raise Exception("tried to save incorrect preimage for payment_hash")
        key = payment_hash.hex()
        if self._preimages.get(key) is not None:
            return  # we already have this preimage
        self.logger.debug(f"saving preimage for {key}")
        self._preimages[key] = preimage.hex()
        if write_to_disk:
            self.wallet.save_db()
    def get_preimage(self, payment_hash: bytes) -> Optional[bytes]:
        """Return the stored preimage for *payment_hash*, or None if we don't have it."""
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
        preimage_hex = self._preimages.get(payment_hash.hex())
        if preimage_hex is None:
            return None
        preimage = bytes.fromhex(preimage_hex)
        # sanity check against db corruption
        if sha256(preimage) != payment_hash:
            raise Exception("found incorrect preimage for payment_hash")
        return preimage
    def get_preimage_hex(self, payment_hash: str) -> Optional[str]:
        """Hex-string variant of get_preimage; returns None when no preimage is stored."""
        preimage = self.get_preimage(bytes.fromhex(payment_hash))
        return preimage.hex() if preimage else None
    def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
        """returns None if payment_hash is a payment we are forwarding"""
        key = payment_hash.hex()
        with self.lock:
            stored_tuple = self.payment_info.get(key)
            if stored_tuple is None:
                return None
            # stored as a plain tuple (without the hash); rebuild the dataclass
            amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple
            return PaymentInfo(
                payment_hash=payment_hash,
                amount_msat=amount_msat,
                direction=direction,
                status=status,
                min_final_cltv_delta=min_final_cltv_delta,
                expiry_delay=expiry_delay,
                creation_ts=creation_ts,
            )
    def add_payment_info_for_hold_invoice(
        self,
        payment_hash: bytes, *,
        lightning_amount_sat: Optional[int],
        min_final_cltv_delta: int,
        exp_delay: int,
    ):
        """Create and store an unpaid incoming PaymentInfo for a hold invoice."""
        amount_msat = lightning_amount_sat * 1000 if lightning_amount_sat else None
        self.save_payment_info(
            PaymentInfo(
                payment_hash=payment_hash,
                amount_msat=amount_msat,
                direction=RECEIVED,
                status=PR_UNPAID,
                min_final_cltv_delta=min_final_cltv_delta,
                expiry_delay=exp_delay,
            ),
            write_to_disk=False,
        )
    def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
        """Register the settlement callback to be invoked for this hold invoice."""
        assert self.get_preimage(payment_hash) is None, "hold invoice cb won't get called if preimage is already set"
        self.hold_invoice_callbacks[payment_hash] = cb
    def unregister_hold_invoice(self, payment_hash: bytes):
        """Drop the hold-invoice callback; fail any pending htlc set we cannot settle."""
        self.hold_invoice_callbacks.pop(payment_hash, None)
        payment_key = self._get_payment_key(payment_hash).hex()
        if payment_key in self.received_mpp_htlcs and self.get_preimage(payment_hash) is None:
            # the pending mpp set can be failed as we don't have the preimage to settle it
            self.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED)
    def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
        """Store *info* keyed by its payment_hash.

        Raises if a conflicting PaymentInfo already exists under the same hash:
        an existing entry may only differ in status (and, for sending attempts,
        in creation_ts); anything else indicates payment_hash reuse.
        """
        assert info.status in SAVED_PR_STATUS
        with self.lock:
            if old_info := self.get_payment_info(payment_hash=info.payment_hash):
                if info == old_info:
                    return  # already saved
                if info.direction == SENT:
                    # allow saving of newer PaymentInfo if it is a sending attempt
                    old_info = dataclasses.replace(old_info, creation_ts=info.creation_ts)
                if info != dataclasses.replace(old_info, status=info.status):
                    # differs more than in status. let's fail
                    raise Exception(f"payment_hash already in use: {info=} != {old_info=}")
            key = info.payment_hash.hex()
            self.payment_info[key] = dataclasses.astuple(info)[1:]  # drop the payment hash at index 0
        if write_to_disk:
            self.wallet.save_db()
    def update_or_create_mpp_with_received_htlc(
        self,
        *,
        payment_key: str,
        scid: ShortChannelID,
        htlc: UpdateAddHtlc,
        unprocessed_onion_packet: str,
    ):
        """Add a validated htlc to the htlc set for *payment_key*, creating a
        new set in WAITING state if none exists.

        Raises OnionRoutingFailure when the existing set is past WAITING and
        cannot safely accept more htlcs (MPP_TIMEOUT for expired sets).
        """
        # Payment key creation:
        #   * for regular forwarded htlcs -> "scid.hex() + ':%d' % htlc_id" [htlc key]
        #   * for trampoline forwarding -> "payment hash + payment secret from outer onion"
        #   * for final non-trampoline htlcs (we are receiver) -> "payment hash + payment secret from onion"
        #   * for final trampoline htlcs (we are receiver) -> 2. step grouping:
        #           1. grouping of htlcs by "payments hash + outer onion payment secret", a 'multi-trampoline mpp part'.
        #           2. once the set of step 1. is COMPLETE (amount_fwd outer onion >= total_amt outer onion)
        #              the htlcs get moved to the parent mpp set (created once first part is complete) grouped by:
        #              "payment_hash + inner onion payment secret (the one in the invoice)"
        #              After moving the htlcs the first set gets deleted.
        #
        # Add the validated htlc to the htlc set associated with the payment key.
        # If no set exists, a new set in WAITING state is created.
        mpp_status = self.received_mpp_htlcs.get(payment_key)
        if mpp_status is None:
            self.logger.debug(f"creating new mpp set for {payment_key=}")
            mpp_status = ReceivedMPPStatus(
                resolution=RecvMPPResolution.WAITING,
                htlcs=set(),
            )

        if mpp_status.resolution > RecvMPPResolution.WAITING:
            # we are getting a htlc for a set that is not in WAITING state, it cannot be safely added
            self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}")
            # fix: compare the set's *resolution* against the enum (comparing the
            # ReceivedMPPStatus tuple itself was always False, so MPP_TIMEOUT was never raised)
            if mpp_status.resolution == RecvMPPResolution.EXPIRED:
                raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
            raise OnionRoutingFailure(
                code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS,
                data=htlc.amount_msat.to_bytes(8, byteorder="big"),
            )

        new_htlc = ReceivedMPPHtlc(
            scid=scid,
            htlc=htlc,
            unprocessed_onion=unprocessed_onion_packet,
        )
        assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?"
        assert isinstance(unprocessed_onion_packet, str)
        mpp_status.htlcs.add(new_htlc)  # side-effecting htlc_set
        self.received_mpp_htlcs[payment_key] = mpp_status
    def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus:
        """Transition the htlc set's resolution, persist it, and return the updated status.

        No-op if the resolution is unchanged. Raises ValueError for transitions
        not whitelisted in lnutil.allowed_mpp_set_transitions.
        """
        mpp_status = self.received_mpp_htlcs[payment_key]
        if mpp_status.resolution == new_resolution:
            return mpp_status
        # idiom fix: "not (x) in y" -> "x not in y"
        if (mpp_status.resolution, new_resolution) not in lnutil.allowed_mpp_set_transitions:
            raise ValueError(f'forbidden mpp set transition: {mpp_status.resolution} -> {new_resolution}')
        self.logger.info(f'set_mpp_resolution {new_resolution.name} {len(mpp_status.htlcs)=}: {payment_key=}')
        self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=new_resolution)
        self.wallet.save_db()
        return self.received_mpp_htlcs[payment_key]
    def set_htlc_set_error(
        self,
        payment_key: str,
        error: Optional[Union[bytes, OnionFailureCode, OnionRoutingFailure]],
    ) -> Optional[Tuple[Optional[bytes], Optional[OnionFailureCode], Optional[bytes]]]:
        """handles different types of errors and sets the htlc set to failed, then returns a more
        structured tuple of error types which can then be used to fail the htlc set"""
        if error is None:
            return None

        htlc_set = self.received_mpp_htlcs[payment_key]
        assert htlc_set.resolution != RecvMPPResolution.SETTLING
        raw_error = error_code = error_data = None
        if isinstance(error, bytes):
            raw_error = error
        elif isinstance(error, OnionFailureCode):
            error_code = error
        elif isinstance(error, OnionRoutingFailure):
            error_code, error_data = OnionFailureCode(error.code), error.data
        else:
            raise ValueError(f"invalid error type: {repr(error)}")

        # MPP_TIMEOUT marks the set as expired, everything else as failed
        new_resolution = (
            RecvMPPResolution.EXPIRED if error_code == OnionFailureCode.MPP_TIMEOUT
            else RecvMPPResolution.FAILED
        )
        self.set_mpp_resolution(payment_key=payment_key, new_resolution=new_resolution)
        return raw_error, error_code, error_data
    def get_mpp_resolution(self, payment_hash: bytes) -> Optional[RecvMPPResolution]:
        """Return the resolution of the htlc set for *payment_hash*, if one exists."""
        payment_key = self._get_payment_key(payment_hash).hex()
        mpp_status = self.received_mpp_htlcs.get(payment_key)
        return mpp_status.resolution if mpp_status else None
    def is_complete_mpp(self, payment_hash: bytes) -> bool:
        """Whether the htlc set for this payment received the full amount (or is settling)."""
        resolution = self.get_mpp_resolution(payment_hash)
        if resolution is None:
            return False
        return resolution in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING)
    def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]:
        """Returns the received mpp amount for given payment hash."""
        payment_key = self._get_payment_key(payment_hash)
        total_msat = self.get_mpp_amounts(payment_key)
        # note: a zero total is reported as None, same as "no htlc set"
        return total_msat if total_msat else None
    def get_mpp_amounts(self, payment_key: bytes) -> Optional[int]:
        """Returns total received amount or None."""
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
        if mpp_status is None:
            return None
        return sum(mpp_htlc.htlc.amount_msat for mpp_htlc in mpp_status.htlcs)
    def maybe_cleanup_mpp(
            self,
            chan: Channel,
    ) -> None:
        """
        Remove all remaining mpp htlcs of the given channel after closing.
        Usually they get removed in htlc_switch after all htlcs of the set are resolved,
        however if there is a force close with pending htlcs they need to be removed after the channel
        is closed.
        """
        # only cleanup when channel is REDEEMED as mpp set is still required for lnsweep
        assert chan._state == ChannelState.REDEEMED
        # iterate over a copy: we may delete entries from received_mpp_htlcs below
        for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
            htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id]
            for stale_mpp_htlc in htlcs_to_remove:
                # a set still WAITING should not have survived to channel redemption
                assert mpp_status.resolution != RecvMPPResolution.WAITING
                self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
                mpp_status.htlcs.remove(stale_mpp_htlc)  # side-effecting htlc_set
            if len(mpp_status.htlcs) == 0:
                # last htlc of the set is gone: drop the set and its forwarding bookkeeping
                self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
                del self.received_mpp_htlcs[payment_key_hex]
                self.maybe_cleanup_forwarding(payment_key_hex)
    def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
        """Drop any forwarding bookkeeping kept for this payment key."""
        for table in (self.active_forwardings, self.forwarding_failures):
            table.pop(payment_key_hex, None)
    def get_payment_status(self, payment_hash: bytes) -> int:
        """Return the saved status for this payment; PR_UNPAID when unknown."""
        info = self.get_payment_info(payment_hash)
        if info:
            return info.status
        return PR_UNPAID
    def get_invoice_status(self, invoice: BaseInvoice) -> int:
        """Effective status of an outgoing invoice, taking in-flight attempts
        and logged failures into account."""
        invoice_id = invoice.rhash
        status = self.get_payment_status(bfh(invoice_id))
        if status == PR_UNPAID:
            if invoice_id in self.inflight_payments:
                return PR_INFLIGHT
            if invoice_id in self.logs:
                # status may be PR_FAILED
                status = PR_FAILED
        return status
    def set_invoice_status(self, key: str, status: int) -> None:
        """Update the status of outgoing invoice *key*, track in-flight state and notify the GUI."""
        if status == PR_INFLIGHT:
            self.inflight_payments.add(key)
        else:
            self.inflight_payments.discard(key)
        if status in SAVED_PR_STATUS:
            self.set_payment_status(bfh(key), status)
        util.trigger_callback('invoice_status', self.wallet, key, status)
        self.logger.info(f"set_invoice_status {key}: {status}")
        # liquidity changed
        self.clear_invoices_cache()
    def set_request_status(self, payment_hash: bytes, status: int) -> None:
        """Update the status of an incoming payment request; notify the GUI if it still exists."""
        if self.get_payment_status(payment_hash) == status:
            return  # nothing changed
        self.set_payment_status(payment_hash, status)
        request_id = payment_hash.hex()
        if self.wallet.get_request(request_id) is None:
            return  # request was deleted
        util.trigger_callback('request_status', self.wallet, request_id, status)
    def set_payment_status(self, payment_hash: bytes, status: int) -> None:
        """Persist a new status for the payment; no-op for hashes we merely forward."""
        info = self.get_payment_info(payment_hash)
        if info is None:
            # if we are forwarding
            return
        self.save_payment_info(dataclasses.replace(info, status=status))
    def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
        """Returns whether this was a forwarded HTLC."""
        return next(
            (payment_key for payment_key, htlcs in self.active_forwardings.items() if htlc_key in htlcs),
            None,
        )
    def notify_upstream_peer(self, htlc_key: str) -> None:
        """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
        If we find this was a forwarded HTLC, the upstream peer is notified.
        """
        upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None)
        if not upstream_key:
            return
        upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
        upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
        if upstream_chan and (upstream_peer := self.peers.get(upstream_chan.node_id)):
            # pulse the event so the upstream peer's htlc switch re-runs
            upstream_peer.downstream_htlc_resolved_event.set()
            upstream_peer.downstream_htlc_resolved_event.clear()
    def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
        """Handle irrevocable fulfillment of an HTLC we offered on *chan*.

        For payments we initiated, this feeds a success HtlcLog back into the
        payment session; for forwarded htlcs, the upstream peer is notified
        once all parts are resolved; otherwise the invoice is marked paid.
        """
        util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
        # payment_key of the forwarding this htlc belonged to, if any
        fw_key = self.is_forwarded_htlc(htlc_key)
        if fw_key:
            fw_htlcs = self.active_forwardings[fw_key]
            fw_htlcs.remove(htlc_key)

        shi = self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id))
        if shi and htlc_id in chan.onion_keys:
            # htlc belongs to a payment we initiated: report success to the paysession
            chan.pop_onion_key(htlc_id)
            payment_key = payment_hash + shi.payment_secret_orig
            paysession = self._paysessions[payment_key]
            q = paysession.sent_htlcs_q
            htlc_log = HtlcLog(
                success=True,
                route=shi.route,
                amount_msat=shi.amount_receiver_msat,
                trampoline_fee_level=shi.trampoline_fee_level)
            q.put_nowait(htlc_log)
            if paysession.can_be_deleted():
                self._paysessions.pop(payment_key)
                paysession_active = False
            else:
                paysession_active = True
        else:
            if fw_key:
                paysession_active = False
            else:
                # not ours and not forwarded: an invoice we issued got paid
                key = payment_hash.hex()
                self.set_invoice_status(key, PR_PAID)
                util.trigger_callback('payment_succeeded', self.wallet, key)

        if fw_key:
            fw_htlcs = self.active_forwardings[fw_key]
            # only notify upstream once every forwarded part and any local session are done
            if len(fw_htlcs) == 0 and not paysession_active:
                self.notify_upstream_peer(htlc_key)
    def htlc_failed(
            self,
            chan: Channel,
            payment_hash: bytes,
            htlc_id: int,
            error_bytes: Optional[bytes],
            failure_message: Optional['OnionRoutingFailure']):
        """Handle irrevocable failure of an HTLC we offered on *chan*.

        *error_bytes* is the (still encrypted) onion error, if we got one;
        *failure_message* is an already-decoded failure (e.g. from
        update_fail_malformed_htlc).
        """
        # note: this may be called several times for the same htlc

        util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
        # payment_key of the forwarding this htlc belonged to, if any
        fw_key = self.is_forwarded_htlc(htlc_key)
        if fw_key:
            fw_htlcs = self.active_forwardings[fw_key]
            fw_htlcs.remove(htlc_key)

        shi = self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id))
        if shi and htlc_id in chan.onion_keys:
            # htlc belongs to a payment we initiated: decode the error and
            # report the failure to the paysession
            onion_key = chan.pop_onion_key(htlc_id)
            payment_okey = payment_hash + shi.payment_secret_orig
            paysession = self._paysessions[payment_okey]
            q = paysession.sent_htlcs_q
            # detect if it is part of a bucket
            # if yes, wait until the bucket completely failed
            route = shi.route
            if error_bytes:
                # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
                try:
                    failure_message, sender_idx = decode_onion_error(
                        error_bytes,
                        [x.node_id for x in route],
                        onion_key)
                except Exception as e:
                    sender_idx = None
                    failure_message = OnionRoutingFailure(OnionFailureCode.INVALID_ONION_PAYLOAD, str(e).encode())
            else:
                # probably got "update_fail_malformed_htlc". well... who to penalise now?
                assert failure_message is not None
                sender_idx = None
            self.logger.info(f"htlc_failed {failure_message}")
            amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
            if amount_receiver_msat is None:
                # other parts of the same bucket are still pending; report later
                return
            if shi.trampoline_route:
                route = shi.trampoline_route
            htlc_log = HtlcLog(
                success=False,
                route=route,
                amount_msat=amount_receiver_msat,
                error_bytes=error_bytes,
                failure_msg=failure_message,
                sender_idx=sender_idx,
                trampoline_fee_level=shi.trampoline_fee_level)
            q.put_nowait(htlc_log)
            if paysession.can_be_deleted():
                self._paysessions.pop(payment_okey)
                paysession_active = False
            else:
                paysession_active = True
        else:
            if fw_key:
                paysession_active = False
            else:
                self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
                key = payment_hash.hex()
                invoice = self.wallet.get_invoice(key)
                if invoice and self.get_invoice_status(invoice) != PR_UNPAID:
                    self.set_invoice_status(key, PR_UNPAID)
                    util.trigger_callback('payment_failed', self.wallet, key, '')

        if fw_key:
            fw_htlcs = self.active_forwardings[fw_key]
            # only propagate the failure upstream once every forwarded part
            # and any local session are done
            can_forward_failure = (len(fw_htlcs) == 0) and not paysession_active
            if can_forward_failure:
                self.logger.info(f'htlc_failed: save_forwarding_failure (phash={payment_hash.hex()})')
                self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
                self.notify_upstream_peer(htlc_key)
            else:
                self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})')
    def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False):
        """calculate routing hints (BOLT-11 'r' field)

        Returns a list of ('r', [(node_id, scid_or_alias, fee_base_msat,
        fee_proportional_millionths, cltv_delta)]) tuples.
        """
        routing_hints = []
        if needs_jit:
            # just-in-time channel: hint only the trusted zeroconf node
            node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
            alias_or_scid = self.get_static_jit_scid_alias()
            routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
            # no need for more because we cannot receive enough through the others and mpp is disabled for jit
            channels = []
        else:
            if channels is None:
                channels = list(self.get_channels_for_receiving(amount_msat=amount_msat, include_disconnected=True))
                random.shuffle(channels)  # let's not leak channel order
            scid_to_my_channels = {
                chan.short_channel_id: chan for chan in channels
                if chan.short_channel_id is not None
            }
        for chan in channels:
            # prefer the remote's scid alias over the real scid (privacy for unannounced chans)
            alias_or_scid = chan.get_remote_scid_alias() or chan.short_channel_id
            assert isinstance(alias_or_scid, bytes), alias_or_scid
            channel_info = get_mychannel_info(chan.short_channel_id, scid_to_my_channels)
            # note: as a fallback, if we don't have a channel update for the
            # incoming direction of our private channel, we fill the invoice with garbage.
            # the sender should still be able to pay us, but will incur an extra round trip
            # (they will get the channel update from the onion error)
            # at least, that's the theory. https://github.com/lightningnetwork/lnd/issues/2066
            fee_base_msat = fee_proportional_millionths = 0
            cltv_delta = 1  # lnd won't even try with zero
            missing_info = True
            if channel_info:
                policy = get_mychannel_policy(channel_info.short_channel_id, chan.node_id, scid_to_my_channels)
                if policy:
                    fee_base_msat = policy.fee_base_msat
                    fee_proportional_millionths = policy.fee_proportional_millionths
                    cltv_delta = policy.cltv_delta
                    missing_info = False
            if missing_info:
                self.logger.info(
                    f"Warning. Missing channel update for our channel {chan.short_channel_id}; "
                    f"filling invoice with incorrect data.")
            routing_hints.append(('r', [(
                chan.node_id,
                alias_or_scid,
                fee_base_msat,
                fee_proportional_millionths,
                cltv_delta)]))
        return routing_hints
    def delete_payment_info(self, payment_hash_hex: str):
        """Remove stored payment info for an unpaid invoice/request the user deleted."""
        # This method is called when an invoice or request is deleted by the user.
        # The GUI only lets the user delete invoices or requests that have not been paid.
        # Once an invoice/request has been paid, it is part of the history,
        # and get_lightning_history assumes that payment_info is there.
        assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
        with self.lock:
            self.payment_info.pop(payment_hash_hex, None)
    def get_balance(self, *, frozen=False) -> Decimal:
        """Total local balance (in sat) over non-closed channels; with
        frozen=True only channels frozen for sending are counted."""
        with self.lock:
            total_msat = 0
            for chan in self.channels.values():
                if chan.is_closed():
                    continue
                if frozen and not chan.is_frozen_for_sending():
                    continue
                total_msat += chan.balance(LOCAL)
            return Decimal(total_msat) / 1000
    def get_channels_for_sending(self):
        """Yield channels currently usable for sending payments."""
        for chan in self.channels.values():
            if not chan.is_active() or chan.is_frozen_for_sending():
                continue
            # without gossip (trampoline mode) we can only route via trampoline peers
            if self.channel_db or self.is_trampoline_peer(chan.node_id):
                yield chan
    def fee_estimate(self, amount_sat):
        """Rough guess of the routing fee (in sat, as Decimal) for sending *amount_sat*.

        Here we have to guess a fee, because some callers (submarine swaps)
        use this method to initiate a payment, which would otherwise fail.
        """
        base_msat = 5000               # FIXME ehh.. there ought to be a better way...
        proportional_millionths = 500  # FIXME
        amount_msat = amount_sat * 1000
        # inverse of fee_for_edge_msat
        sendable_msat = (amount_msat - base_msat) * 1_000_000 // (1_000_000 + proportional_millionths)
        return Decimal(amount_msat - sendable_msat) / 1000
    def num_sats_can_send(self, deltas=None) -> Decimal:
        """
        without trampoline, sum of all channel capacity
        with trampoline, MPP must use a single trampoline

        *deltas* maps Channel -> sat amount hypothetically added to that
        channel's spendable balance (what-if estimates).
        """
        if deltas is None:
            deltas = {}

        def send_capacity(chan):
            # spendable balance plus the hypothetical delta, but only when the
            # remote side could actually accommodate the delta
            if chan in deltas:
                delta_msat = deltas[chan] * 1000
                if delta_msat > chan.available_to_spend(REMOTE):
                    delta_msat = 0
            else:
                delta_msat = 0
            return chan.available_to_spend(LOCAL) + delta_msat
        can_send_dict = defaultdict(int)
        with self.lock:
            for c in self.get_channels_for_sending():
                if not self.uses_trampoline():
                    # gossip routing: all channel capacities can be combined via MPP
                    can_send_dict[0] += send_capacity(c)
                else:
                    # trampoline: all parts must go through a single trampoline node
                    can_send_dict[c.node_id] += send_capacity(c)
        can_send = max(can_send_dict.values()) if can_send_dict else 0
        can_send_sat = Decimal(can_send)/1000
        # subtract an estimated routing fee so the reported amount is actually sendable
        can_send_sat -= self.fee_estimate(can_send_sat)
        return max(can_send_sat, 0)
    def get_channels_for_receiving(
        self, *, amount_msat: Optional[int] = None, include_disconnected: bool = False,
    ) -> Sequence[Channel]:
        """Select channels suitable for receiving *amount_msat* (e.g. for invoice hints).

        Channels are ordered by receive capacity; small channels that neither
        fit the amount as a single part nor add much to the running total are
        dropped, and at most 10 channels are returned.
        """
        if not amount_msat:  # assume we want to recv a large amt, e.g. finding max.
            amount_msat = float('inf')
        with self.lock:
            channels = list(self.channels.values())
            channels = [chan for chan in channels
                        if chan.is_open() and not chan.is_frozen_for_receiving()]

            if not include_disconnected:
                channels = [chan for chan in channels if chan.is_active()]

            # Filter out nodes that have low receive capacity compared to invoice amt.
            # Even with MPP, below a certain threshold, including these channels probably
            # hurts more than help, as they lead to many failed attempts for the sender.
            channels = sorted(channels, key=lambda chan: -chan.available_to_spend(REMOTE))
            selected_channels = []
            running_sum = 0
            cutoff_factor = 0.2  # heuristic
            for chan in channels:
                recv_capacity = chan.available_to_spend(REMOTE)
                chan_can_handle_payment_as_single_part = recv_capacity >= amount_msat
                chan_small_compared_to_running_sum = recv_capacity < cutoff_factor * running_sum
                # channels are sorted by capacity, so all subsequent ones are smaller: stop
                if not chan_can_handle_payment_as_single_part and chan_small_compared_to_running_sum:
                    break
                running_sum += recv_capacity
                selected_channels.append(chan)
            channels = selected_channels
            del selected_channels
            # cap max channels to include to keep QR code reasonably scannable
            channels = channels[:10]
            return channels
    def num_sats_can_receive(self, deltas=None) -> Decimal:
        """Return the maximum amount (in sat) we can receive over a single channel.

        We no longer assume the sender to send MPP on different channels,
        because channel liquidities are hard to guess

        `deltas` optionally maps Channel -> extra amount (sat) hypothetically sent
        out over that channel, used to simulate the effect of a rebalance/swap.
        """
        if deltas is None:
            deltas = {}

        def recv_capacity(chan):
            # receive capacity of `chan` in msat, adjusted by the hypothetical delta
            if chan in deltas:
                delta_msat = deltas[chan] * 1000
                # the delta only counts if we could actually send it over this channel
                if delta_msat > chan.available_to_spend(LOCAL):
                    delta_msat = 0
            else:
                delta_msat = 0
            return chan.available_to_spend(REMOTE) + delta_msat
        with self.lock:
            recv_channels = self.get_channels_for_receiving()
            recv_chan_msats = [recv_capacity(chan) for chan in recv_channels]
        if not recv_chan_msats:
            return Decimal(0)
        # single-channel receive: take the best channel, not the sum
        can_receive_msat = max(recv_chan_msats)
        return Decimal(can_receive_msat) / 1000
    def receive_requires_jit_channel(self, amount_msat: Optional[int]) -> bool:
        """Returns true if we cannot receive the amount and have set up a trusted LSP node.
        Cannot work reliably with 0 amount invoices as we don't know if we are able to receive it.
        """
        # without a configured and connected zeroconf provider, a JIT channel is not an option
        if not self.can_get_zeroconf_channel():
            return False
        if amount_msat:
            # a new channel is needed when our receive capacity cannot cover the amount
            return self.num_sats_can_receive() < (amount_msat // 1000)
        # 0-amount invoice: require a JIT channel only if we cannot receive anything at all
        return self.num_sats_can_receive() < 1
    def can_get_zeroconf_channel(self) -> bool:
        """Return True iff a trusted zeroconf (JIT) channel provider is configured,
        zeroconf channels are enabled, and we are currently connected to the provider.
        """
        # Both conditions must hold.  The previous check `not A and B` bound `not`
        # only to ACCEPT_ZEROCONF_CHANNELS (operator precedence) and relied on
        # extract_nodeid() below raising for the remaining falsy combinations;
        # make the intent explicit instead.
        if not (self.config.ACCEPT_ZEROCONF_CHANNELS and self.config.ZEROCONF_TRUSTED_NODE):
            return False
        try:
            node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0]
        except ConnStringFormatError:
            # invalid connection string
            return False
        # only return True if we are connected to the zeroconf provider
        return node_id in self.peers
    def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
        """
        Suggest a channel and amount to send/receive with that channel, so that we will be able to receive/send amount_sat
        This is used when suggesting a swap or rebalance in order to receive a payment

        `direction` is SENT or RECEIVED. Raises NotEnoughFunds if no channel qualifies.
        """
        with self.lock:
            # simulate capacity after shifting `delta` sats through a channel
            func = self.num_sats_can_send if direction == SENT else self.num_sats_can_receive
            suggestions = []
            channels = self.get_channels_for_sending() if direction == SENT else self.get_channels_for_receiving()
            for chan in channels:
                # how much this channel is short of being able to carry amount_sat
                available_sat = chan.available_to_spend(LOCAL if direction == SENT else REMOTE) // 1000
                delta = amount_sat - available_sat
                delta += self.fee_estimate(amount_sat)
                # add safety margin
                delta += delta // 100 + 1
                if func(deltas={chan:delta}) >= amount_sat:
                    suggestions.append((chan, int(delta)))
                elif direction == RECEIVED and func(deltas={chan:2*delta}) >= amount_sat:
                    # MPP heuristics has a 0.5 slope
                    suggestions.append((chan, int(2*delta)))
        if not suggestions:
            raise NotEnoughFunds
        return suggestions
    def _suggest_rebalance(self, direction, amount_sat):
1✔
3093
        """
3094
        Suggest a rebalance in order to be able to send or receive amount_sat.
3095
        Returns (from_channel, to_channel, amount to shuffle)
3096
        """
3097
        try:
×
3098
            suggestions = self._suggest_channels_for_rebalance(direction, amount_sat)
×
3099
        except NotEnoughFunds:
×
3100
            return False
×
3101
        for chan2, delta in suggestions:
×
3102
            # margin for fee caused by rebalancing
3103
            delta += self.fee_estimate(amount_sat)
×
3104
            # find other channel or trampoline that can send delta
3105
            for chan1 in self.channels.values():
×
3106
                if chan1.is_frozen_for_sending() or not chan1.is_active():
×
3107
                    continue
×
3108
                if chan1 == chan2:
×
3109
                    continue
×
3110
                if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
3111
                    continue
×
3112
                if direction == SENT:
×
3113
                    if chan1.can_pay(delta*1000):
×
3114
                        return chan1, chan2, delta
×
3115
                else:
3116
                    if chan1.can_receive(delta*1000):
×
3117
                        return chan2, chan1, delta
×
3118
            else:
3119
                continue
×
3120
        else:
3121
            return False
×
3122

3123
    def num_sats_can_rebalance(self, chan1, chan2):
        """Return the max amount (sat) that can be shuffled from chan1 into chan2."""
        # TODO: we should be able to spend 'max', with variable fee
        sendable_msat = chan1.available_to_spend(LOCAL)
        # leave room for the routing fee of the rebalance payment itself
        sendable_msat -= self.fee_estimate(sendable_msat)
        receivable_msat = chan2.available_to_spend(REMOTE)
        return min(sendable_msat, receivable_msat) // 1000
    def suggest_rebalance_to_send(self, amount_sat):
        """Suggest (from_chan, to_chan, amount) enabling us to send amount_sat."""
        return self._suggest_rebalance(direction=SENT, amount_sat=amount_sat)
    def suggest_rebalance_to_receive(self, amount_sat):
        """Suggest (from_chan, to_chan, amount) enabling us to receive amount_sat."""
        return self._suggest_rebalance(direction=RECEIVED, amount_sat=amount_sat)
    def suggest_swap_to_send(self, amount_sat, coins):
        """Suggest a (channel, swap amount) for a swap-in (onchain -> lightning)
        that would allow us to afterwards send `amount_sat` over lightning.

        `coins` are the onchain UTXOs available to fund the swap.
        Returns None if no suitable channel / onchain funding exists.
        """
        # fixme: if swap_amount_sat is lower than the minimum swap amount, we need to propose a higher value
        assert amount_sat > self.num_sats_can_send()
        try:
            suggestions = self._suggest_channels_for_rebalance(SENT, amount_sat)
        except NotEnoughFunds:
            return None
        for chan, swap_recv_amount in suggestions:
            # check that we can send onchain
            swap_server_mining_fee = 10000 # guessing, because we have not called get_pairs yet
            swap_funding_sat = swap_recv_amount + swap_server_mining_fee
            swap_output = PartialTxOutput.from_address_and_value(DummyAddress.SWAP, int(swap_funding_sat))
            try:
                # check if we have enough onchain funds
                self.wallet.make_unsigned_transaction(
                    coins=coins,
                    outputs=[swap_output],
                    fee_policy=FeePolicy(self.config.FEE_POLICY_SWAPS),
                )
            except NotEnoughFunds:
                continue
            return chan, swap_recv_amount
        return None
    def suggest_swap_to_receive(self, amount_sat: int):
        """Suggest a (channel, amount) for a swap so that we can receive amount_sat.
        Returns None if no suitable channel exists."""
        assert amount_sat > self.num_sats_can_receive(), f"{amount_sat=} | {self.num_sats_can_receive()=}"
        try:
            suggestions = self._suggest_channels_for_rebalance(RECEIVED, amount_sat)
        except NotEnoughFunds:
            return
        # the first suggestion is good enough
        if suggestions:
            return suggestions[0]
        return None
    async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msat: int):
        """Shift `amount_msat` of liquidity from chan1 to chan2, by paying ourselves
        an invoice that is restricted to chan2, using chan1 as the outgoing channel.
        """
        if chan1 == chan2:
            raise Exception('Rebalance requires two different channels')
        if self.uses_trampoline() and chan1.node_id == chan2.node_id:
            raise Exception('Rebalance requires channels from different trampolines')
        payment_hash = self.create_payment_info(
            amount_msat=amount_msat,
            exp_delay=3600,  # give the rebalance payment an hour before it expires
        )
        info = self.get_payment_info(payment_hash)
        lnaddr, invoice = self.get_bolt11_invoice(
            payment_info=info,
            message='rebalance',
            fallback_address=None,
            channels=[chan2],  # receive over chan2 ...
        )
        invoice_obj = Invoice.from_bech32(invoice)
        return await self.pay_invoice(invoice_obj, channels=[chan1])  # ... pay via chan1
    def can_receive_invoice(self, invoice: BaseInvoice) -> bool:
        """Return whether our current receive capacity covers the invoice amount."""
        assert invoice.is_lightning()
        amount_sat = invoice.get_amount_sat() or 0
        return amount_sat <= self.num_sats_can_receive()
    async def close_channel(self, chan_id):
        """Close the channel `chan_id` by delegating to the peer that hosts it."""
        chan = self._channels[chan_id]
        remote_peer = self._peers[chan.node_id]
        return await remote_peer.close_channel(chan_id)
    def _force_close_channel(self, chan_id: bytes) -> Transaction:
        """Prepare a force-close: obtain our commitment tx, mark the channel
        FORCE_CLOSING and add the tx to the wallet. Broadcasting is left to the caller.
        """
        chan = self._channels[chan_id]
        tx = chan.force_close_tx()
        # We set the channel state to make sure we won't sign new commitment txs.
        # We expect the caller to try to broadcast this tx, after which it is
        # not safe to keep using the channel even if the broadcast errors (server could be lying).
        # Until the tx is seen in the mempool, there will be automatic rebroadcasts.
        chan.set_state(ChannelState.FORCE_CLOSING)
        # Add local tx to wallet to also allow manual rebroadcasts.
        try:
            self.wallet.adb.add_transaction(tx)
        except UnrelatedTransactionException:
            pass  # this can happen if (~all the balance goes to REMOTE)
        return tx
    async def force_close_channel(self, chan_id: bytes) -> str:
        """Force-close the channel. Network-related exceptions are propagated to the caller.
        (automatic rebroadcasts will be scheduled)
        Returns the txid of the broadcast commitment transaction.
        """
        # note: as we are async, it can take a few event loop iterations between the caller
        #       "calling us" and us getting to run, and we only set the channel state now:
        closing_tx = self._force_close_channel(chan_id)
        await self.network.broadcast_transaction(closing_tx)
        return closing_tx.txid()
    def schedule_force_closing(self, chan_id: bytes) -> 'asyncio.Task[bool]':
        """Schedules a task to force-close the channel and returns it.
        Network-related exceptions are suppressed.
        (automatic rebroadcasts will be scheduled)
        Note: this method is intentionally not async so that callers have a guarantee
              that the channel state is set immediately.
        """
        closing_tx = self._force_close_channel(chan_id)
        broadcast_coro = self.network.try_broadcasting(closing_tx, 'force-close')
        return asyncio.create_task(broadcast_coro)
    def remove_channel(self, chan_id):
        """Delete a channel (that is safe to delete) from memory and the wallet db,
        then notify listeners."""
        channel = self.channels[chan_id]
        assert channel.can_be_deleted()
        with self.lock:
            self._channels.pop(chan_id)
            self.db.get('channels').pop(chan_id.hex())
        # its addresses can be reused by the wallet again
        self.wallet.set_reserved_addresses_for_chan(channel, reserved=False)

        for event in ('channels_updated', 'wallet_updated'):
            util.trigger_callback(event, self.wallet)
    @ignore_exceptions
    @log_exceptions
    async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
        """Try to (re)connect to the peer of `chan` so the channel can be reestablished.

        Candidate addresses come from the trampoline list or gossip, plus the
        addresses stored with the channel; the first one not in retry-backoff is used.
        """
        now = time.time()
        peer_addresses = []
        if self.uses_trampoline():
            addr = trampolines_by_id().get(chan.node_id)
            if addr:
                peer_addresses.append(addr)
        else:
            # will try last good address first, from gossip
            last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
            if last_good_addr:
                peer_addresses.append(last_good_addr)
            # will try addresses for node_id from gossip
            addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
            for host, port, ts in addrs_from_gossip:
                peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
        # will try addresses stored in channel storage
        peer_addresses += list(chan.get_peer_addresses())
        # Done gathering addresses.
        # Now select first one that has not failed recently.
        for peer in peer_addresses:
            if self._can_retry_addr(peer, urgent=True, now=now):
                await self._add_peer(peer.host, peer.port, peer.pubkey)
                return
    async def reestablish_peers_and_channels(self):
        """Long-running task: once per second, reconnect to peers of channels that
        should be reestablished, and keep a connection to the configured
        zeroconf provider alive. Returns when the wallet is stopping.
        """
        while True:
            await asyncio.sleep(1)
            if self.stopping_soon:
                return
            if self.config.ZEROCONF_TRUSTED_NODE:
                peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
                if self._can_retry_addr(peer, urgent=True):
                    await self._add_peer(peer.host, peer.port, peer.pubkey)
            for chan in self.channels.values():
                # reestablish
                # note: we delegate filtering out uninteresting chans to this:
                if not chan.should_try_to_reestablish_peer():
                    continue
                peer = self._peers.get(chan.node_id, None)
                if peer:
                    # already connected: reestablish the channel on the peer's taskgroup
                    await peer.taskgroup.spawn(peer.reestablish_channel(chan))
                else:
                    # not connected: try to connect first
                    await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
    def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]:
        """Return the target commitment-tx feerate in sat/kw, or None if no
        fee estimates are available."""
        if has_anchors:
            eta_target = FEE_LN_MINIMUM_ETA_TARGET
        else:
            eta_target = FEE_LN_ETA_TARGET
        kvbyte_rate = self.network.fee_estimates.eta_target_to_fee(eta_target)
        if kvbyte_rate is None:
            return None
        if has_anchors:
            # set a floor of 5 sat/vb to have some safety margin in case the mempool
            # grows quickly
            kvbyte_rate = max(kvbyte_rate, 5000)
        # convert sat/kvB -> sat/kw and clamp to the minimum relay feerate
        return max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, kvbyte_rate // 4)
    def current_low_feerate_per_kw_srk_channel(self) -> Optional[int]:
        """Gets low feerate for static remote key channels.

        Returns sat/kw, or None if fee estimates are unavailable.
        """
        if constants.net is constants.BitcoinRegtest:
            # regtest has no fee market; rely on the relay-minimum floor below
            feerate_per_kvbyte = 0
        else:
            feerate_per_kvbyte = self.network.fee_estimates.eta_target_to_fee(FEE_LN_LOW_ETA_TARGET)
            if feerate_per_kvbyte is None:
                return None
        # convert sat/kvB -> sat/kw, clamped to the minimum relay feerate
        low_feerate_per_kw = max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
        # make sure this is never higher than the target feerate:
        current_target_feerate = self.current_target_feerate_per_kw(has_anchors=False)
        if not current_target_feerate:
            return None
        low_feerate_per_kw = min(low_feerate_per_kw, current_target_feerate)
        return low_feerate_per_kw
    def create_channel_backup(self, channel_id: bytes):
        """Build an ImportedChannelBackupStorage for the given channel, capturing
        the keys and funding details needed to restore it as a backup elsewhere.

        Only static-remotekey channels can be backed up; only the first known
        peer address is stored.
        """
        chan = self._channels[channel_id]
        # do not backup old-style channels
        assert chan.is_static_remotekey_enabled()
        peer_addresses = list(chan.get_peer_addresses())
        peer_addr = peer_addresses[0]
        return ImportedChannelBackupStorage(
            node_id=chan.node_id,
            privkey=self.node_keypair.privkey,
            funding_txid=chan.funding_outpoint.txid,
            funding_index=chan.funding_outpoint.output_index,
            funding_address=chan.get_funding_address(),
            host=peer_addr.host,
            port=peer_addr.port,
            is_initiator=chan.constraints.is_initiator,
            channel_seed=chan.config[LOCAL].channel_seed,
            local_delay=chan.config[LOCAL].to_self_delay,
            remote_delay=chan.config[REMOTE].to_self_delay,
            remote_revocation_pubkey=chan.config[REMOTE].revocation_basepoint.pubkey,
            remote_payment_pubkey=chan.config[REMOTE].payment_basepoint.pubkey,
            local_payment_pubkey=chan.config[LOCAL].payment_basepoint.pubkey,
            multisig_funding_privkey=chan.config[LOCAL].multisig_key.privkey,
        )
    def export_channel_backup(self, channel_id):
        """Serialize and encrypt a backup of the given channel for export,
        returned as a 'channel_backup:...' string."""
        fingerprint = self.wallet.get_fingerprint()
        raw = self.create_channel_backup(channel_id).to_bytes()
        # sanity check: serialization must round-trip
        assert raw == ImportedChannelBackupStorage.from_bytes(raw).to_bytes(), "roundtrip failed"
        encrypted = pw_encode_with_version_and_mac(raw, fingerprint)
        # sanity check: encryption must be reversible with the same key
        assert raw == pw_decode_with_version_and_mac(encrypted, fingerprint), "encrypt failed"
        return 'channel_backup:' + encrypted
    async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
        """Ask the remote peer to force-close the channel, for a live channel,
        an explicit connection string, or a channel backup."""
        if channel_id in self.channels:
            chan = self.channels[channel_id]
            known_peer = self._peers.get(chan.node_id)
            # flag the channel so the request is sent on (re)connection
            chan.should_request_force_close = True
            if known_peer:
                known_peer.close_and_cleanup()  # to force a reconnect
            return
        if connect_str:
            new_peer = await self.add_peer(connect_str)
            await new_peer.request_force_close(channel_id)
            return
        if channel_id in self.channel_backups:
            await self._request_force_close_from_backup(channel_id)
            return
        raise Exception(f'Unknown channel {channel_id.hex()}')
    def import_channel_backup(self, data):
        """Decrypt and import a 'channel_backup:...' string (see export_channel_backup),
        persist it, and start watching the channel."""
        # the wallet fingerprint is the decryption password used at export time
        xpub = self.wallet.get_fingerprint()
        cb_storage = ImportedChannelBackupStorage.from_encrypted_str(data, password=xpub)
        channel_id = cb_storage.channel_id()
        if channel_id.hex() in self.db.get_dict("channels"):
            raise Exception('Channel already in wallet')
        self.logger.info(f'importing channel backup: {channel_id.hex()}')
        d = self.db.get_dict("imported_channel_backups")
        d[channel_id.hex()] = cb_storage
        with self.lock:
            cb = ChannelBackup(cb_storage, lnworker=self)
            self._channel_backups[channel_id] = cb
        # reserve the channel's addresses so the wallet does not reuse them
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
        self.wallet.save_db()
        util.trigger_callback('channels_updated', self.wallet)
        self.lnwatcher.add_channel(cb)
    def has_conflicting_backup_with(self, remote_node_id: bytes):
        """ Returns whether we have an active channel with this node on another device, using same local node id. """
        our_pubkey = self.node_keypair.pubkey
        # node ids (possibly prefixes) of peers of non-closed backups made with our key
        matching_node_ids = []
        for backup in self.channel_backups.values():
            if not backup.is_closed() and backup.get_local_pubkey() == our_pubkey:
                matching_node_ids.append(backup.node_id)
        return any(remote_node_id.startswith(node_id) for node_id in matching_node_ids)
    def remove_channel_backup(self, channel_id):
        """Delete a (deletable) channel backup from the wallet db and notify listeners."""
        chan = self.channel_backups[channel_id]
        assert chan.can_be_deleted()
        found = False
        onchain_backups = self.db.get_dict("onchain_channel_backups")
        imported_backups = self.db.get_dict("imported_channel_backups")
        # the backup may live in either store; remove it wherever it is found
        if channel_id.hex() in onchain_backups:
            onchain_backups.pop(channel_id.hex())
            found = True
        if channel_id.hex() in imported_backups:
            imported_backups.pop(channel_id.hex())
            found = True
        if not found:
            raise Exception('Channel not found')
        with self.lock:
            self._channel_backups.pop(channel_id)
        # release the channel's reserved addresses back to the wallet
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
        self.wallet.save_db()
        util.trigger_callback('channels_updated', self.wallet)
    @log_exceptions
    async def _request_force_close_from_backup(self, channel_id: bytes):
        """Connect to the remote peer of a backed-up channel and ask it to force-close.

        Addresses from the backup (or the trampoline list) are tried first;
        if that fails, addresses from the gossip db are tried. Raises on failure.
        """
        cb = self.channel_backups.get(channel_id)
        if not cb:
            raise Exception(f'channel backup not found {self.channel_backups}')
        cb = cb.cb # storage
        self.logger.info(f'requesting channel force close: {channel_id.hex()}')
        if isinstance(cb, ImportedChannelBackupStorage):
            node_id = cb.node_id
            privkey = cb.privkey
            addresses = [(cb.host, cb.port, 0)]
        else:
            assert isinstance(cb, OnchainChannelBackupStorage)
            privkey = self.node_keypair.privkey
            # on-chain backups only store a node-id prefix; try to resolve it
            # against the known trampoline nodes first
            for pubkey, peer_addr in trampolines_by_id().items():
                if pubkey.startswith(cb.node_id_prefix):
                    node_id = pubkey
                    addresses = [(peer_addr.host, peer_addr.port, 0)]
                    break
            else:
                # we will try with gossip (see below)
                addresses = []

        async def _request_fclose(addresses):
            # try each address until one connection + force-close request succeeds
            for host, port, timestamp in addresses:
                peer_addr = LNPeerAddr(host, port, node_id)
                transport = LNTransport(privkey, peer_addr, e_proxy=ESocksProxy.from_network_settings(self.network))
                peer = Peer(self, node_id, transport, is_channel_backup=True)
                try:
                    # run the message loop just long enough for the request to go through
                    async with OldTaskGroup(wait=any) as group:
                        await group.spawn(peer._message_loop())
                        await group.spawn(peer.request_force_close(channel_id))
                    return True
                except Exception as e:
                    self.logger.info(f'failed to connect {host} {e}')
                    continue
            else:
                return False
        # try first without gossip db
        success = await _request_fclose(addresses)
        if success:
            return
        # try with gossip db
        if self.uses_trampoline():
            raise Exception(_('Please enable gossip'))
        node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
        addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)
        if not addresses_from_gossip:
            raise Exception('Peer not found in gossip database')
        success = await _request_fclose(addresses_from_gossip)
        if not success:
            raise Exception('failed to connect')
    def maybe_add_backup_from_tx(self, tx):
        """Scan `tx` for a p2wsh funding output accompanied by an OP_RETURN output
        carrying our encrypted backup marker; if found, create and register an
        on-chain channel backup. No-op otherwise.
        """
        funding_address = None
        node_id_prefix = None
        for i, o in enumerate(tx.outputs()):
            script_type = get_script_type_from_output_script(o.scriptpubkey)
            if script_type == 'p2wsh':
                # candidate funding output; look for a matching OP_RETURN marker
                funding_index = i
                funding_address = o.address
                for o2 in tx.outputs():
                    if o2.scriptpubkey.startswith(bytes([opcodes.OP_RETURN])):
                        # skip OP_RETURN + push opcode; the rest is the encrypted payload
                        encrypted_data = o2.scriptpubkey[2:]
                        data = self.decrypt_cb_data(encrypted_data, funding_address)
                        if data.startswith(CB_MAGIC_BYTES):
                            node_id_prefix = data[len(CB_MAGIC_BYTES):]
        if node_id_prefix is None:
            return
        funding_txid = tx.txid()
        cb_storage = OnchainChannelBackupStorage(
            node_id_prefix=node_id_prefix,
            funding_txid=funding_txid,
            funding_index=funding_index,
            funding_address=funding_address,
            is_initiator=True)
        channel_id = cb_storage.channel_id().hex()
        if channel_id in self.db.get_dict("channels"):
            return  # we have the channel itself; no backup needed
        self.logger.info(f"adding backup from tx")
        d = self.db.get_dict("onchain_channel_backups")
        d[channel_id] = cb_storage
        cb = ChannelBackup(cb_storage, lnworker=self)
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
        self.wallet.save_db()
        with self.lock:
            self._channel_backups[bfh(channel_id)] = cb
        util.trigger_callback('channels_updated', self.wallet)
        self.lnwatcher.add_channel(cb)
    async def maybe_forward_htlc_set(
        self,
        payment_key: str, *,
        processed_htlc_set: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]],
    ) -> None:
        """Start forwarding a complete incoming HTLC set.

        Handles either a single non-trampoline HTLC (forwarded to the next
        channel) or a possibly multi-part trampoline set. If forwarding fails
        immediately with an OnionRoutingFailure before anything was forwarded,
        the failure is recorded so the incoming HTLCs can be failed back.
        """
        assert self.enable_htlc_forwarding
        assert payment_key not in self.active_forwardings, "cannot forward set twice"
        self.active_forwardings[payment_key] = []
        self.logger.debug(f"adding active_forwarding: {payment_key=}")

        # inspect an arbitrary member of the set to decide the forwarding mode
        any_mpp_htlc, (any_outer_onion, any_trampoline_onion) = next(iter(processed_htlc_set.items()))
        try:
            if any_trampoline_onion is None:
                # plain (non-trampoline) forwarding: must be a single HTLC
                assert not any_outer_onion.are_we_final
                assert len(processed_htlc_set) == 1, processed_htlc_set
                forward_htlc = any_mpp_htlc.htlc
                incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid)
                next_htlc = await self._maybe_forward_htlc(
                    incoming_chan=incoming_chan,
                    htlc=forward_htlc,
                    processed_onion=any_outer_onion,
                )
                # remember the mapping so a downstream failure/settle can be
                # propagated back to the incoming HTLC
                htlc_key = serialize_htlc_key(incoming_chan.get_scid_or_local_alias(), forward_htlc.htlc_id)
                self.active_forwardings[payment_key].append(next_htlc)
                self.downstream_to_upstream_htlc[next_htlc] = htlc_key
            else:
                assert not any_trampoline_onion.are_we_final and any_outer_onion.are_we_final
                # trampoline forwarding
                min_inc_cltv_abs = min(
                    mpp_htlc.htlc.cltv_abs
                    for mpp_htlc in processed_htlc_set.keys())  # take "min" to assume worst-case
                await self._maybe_forward_trampoline(
                    payment_hash=any_mpp_htlc.htlc.payment_hash,
                    closest_inc_cltv_abs=min_inc_cltv_abs,
                    total_msat=any_outer_onion.total_msat,
                    any_trampoline_onion=any_trampoline_onion,
                    fw_payment_key=payment_key,
                )
        except OnionRoutingFailure as e:
            self.logger.debug(f"forwarding failed: {e=}")
            if len(self.active_forwardings[payment_key]) == 0:
                self.save_forwarding_failure(payment_key, failure_message=e)
        # TODO what about other errors?
        #      Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE?
        #        - we don't want to fail the inc-HTLC for a syntax error that happens in the callback
        #      If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry
        #      and then the inc-channel will get force-closed.
        #      => forwarding_callback() could have an API with two exceptions types:
        #        - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC
        #        - type2, such as NoPathFound, that signals we want to retry forwarding
    async def _maybe_forward_htlc(
1✔
3551
            self, *,
3552
            incoming_chan: Channel,
3553
            htlc: UpdateAddHtlc,
3554
            processed_onion: ProcessedOnionPacket,
3555
    ) -> str:
3556

3557
        # Forward HTLC
3558
        # FIXME: there are critical safety checks MISSING here
3559
        #        - for example; atm we forward first and then persist "forwarding_info",
3560
        #          so if we segfault in-between and restart, we might forward an HTLC twice...
3561
        #          (same for trampoline forwarding)
3562
        #        - we could check for the exposure to dust HTLCs, see:
3563
        #          https://github.com/ACINQ/eclair/pull/1985
3564

3565
        def log_fail_reason(reason: str):
1✔
3566
            self.logger.debug(
1✔
3567
                f"_maybe_forward_htlc. will FAIL HTLC: inc_chan={incoming_chan.get_id_for_log()}. "
3568
                f"{reason}. inc_htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}")
3569

3570
        forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
1✔
3571
        if not forwarding_enabled:
1✔
3572
            log_fail_reason("forwarding is disabled")
×
3573
            raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
×
3574
        chain = self.network.blockchain()
1✔
3575
        if chain.is_tip_stale():
1✔
3576
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
×
3577
        if (next_chan_scid := processed_onion.next_chan_scid) is None:
1✔
3578
            raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
×
3579
        if (next_amount_msat_htlc := processed_onion.amt_to_forward) is None:
1✔
3580
            raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
×
3581
        if (next_cltv_abs := processed_onion.outgoing_cltv_value) is None:
1✔
3582
            raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
×
3583

3584
        next_chan = self.get_channel_by_short_id(next_chan_scid)
1✔
3585

3586
        if self.features.supports(LnFeatures.OPTION_ZEROCONF_OPT):
1✔
3587
            next_peer = self.get_peer_by_static_jit_scid_alias(next_chan_scid)
×
3588
        else:
3589
            next_peer = None
1✔
3590

3591
        if not next_chan and next_peer and next_peer.accepts_zeroconf():
1✔
3592
            # check if an already existing channel can be used.
3593
            # todo: split the payment
3594
            for next_chan in next_peer.channels.values():
×
3595
                if next_chan.can_pay(next_amount_msat_htlc):
×
3596
                    break
×
3597
            else:
3598
                return await self.open_channel_just_in_time(
×
3599
                    next_peer=next_peer,
3600
                    next_amount_msat_htlc=next_amount_msat_htlc,
3601
                    next_cltv_abs=next_cltv_abs,
3602
                    payment_hash=htlc.payment_hash,
3603
                    next_onion=processed_onion.next_packet)
3604

3605
        local_height = chain.height()
1✔
3606
        if next_chan is None:
1✔
3607
            log_fail_reason(f"cannot find next_chan {next_chan_scid}")
×
3608
            raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
×
3609
        outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update(scid=next_chan_scid)[2:]
1✔
3610
        outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
1✔
3611
        outgoing_chan_upd_message = outgoing_chan_upd_len + outgoing_chan_upd
1✔
3612
        if not next_chan.can_send_update_add_htlc():
1✔
3613
            log_fail_reason(
×
3614
                f"next_chan {next_chan.get_id_for_log()} cannot send ctx updates. "
3615
                f"chan state {next_chan.get_state()!r}, peer state: {next_chan.peer_state!r}")
3616
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
×
3617
        if not next_chan.can_pay(next_amount_msat_htlc):
1✔
3618
            log_fail_reason(f"transient error (likely due to insufficient funds): not next_chan.can_pay(amt)")
1✔
3619
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
1✔
3620
        if htlc.cltv_abs - next_cltv_abs < next_chan.forwarding_cltv_delta:
1✔
3621
            log_fail_reason(
×
3622
                f"INCORRECT_CLTV_EXPIRY. "
3623
                f"{htlc.cltv_abs=} - {next_cltv_abs=} < {next_chan.forwarding_cltv_delta=}")
3624
            data = htlc.cltv_abs.to_bytes(4, byteorder="big") + outgoing_chan_upd_message
×
3625
            raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data)
×
3626
        if htlc.cltv_abs - lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED <= local_height \
1✔
3627
                or next_cltv_abs <= local_height:
3628
            raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_SOON, data=outgoing_chan_upd_message)
×
3629
        if max(htlc.cltv_abs, next_cltv_abs) > local_height + lnutil.NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
1✔
3630
            raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'')
×
3631
        forwarding_fees = fee_for_edge_msat(
1✔
3632
            forwarded_amount_msat=next_amount_msat_htlc,
3633
            fee_base_msat=next_chan.forwarding_fee_base_msat,
3634
            fee_proportional_millionths=next_chan.forwarding_fee_proportional_millionths)
3635
        if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees:
1✔
3636
            data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_message
×
3637
            raise OnionRoutingFailure(code=OnionFailureCode.FEE_INSUFFICIENT, data=data)
×
3638
        if self._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created(htlc.payment_hash):
1✔
3639
            log_fail_reason(f"RHASH corresponds to payreq we created")
1✔
3640
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
1✔
3641
        self.logger.info(
1✔
3642
            f"maybe_forward_htlc. will forward HTLC: inc_chan={incoming_chan.short_channel_id}. inc_htlc={str(htlc)}. "
3643
            f"next_chan={next_chan.get_id_for_log()}.")
3644

3645
        next_peer = self.peers.get(next_chan.node_id)
1✔
3646
        if next_peer is None:
1✔
3647
            log_fail_reason(f"next_peer offline ({next_chan.node_id.hex()})")
×
3648
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
×
3649
        try:
1✔
3650
            next_htlc = next_peer.send_htlc(
1✔
3651
                chan=next_chan,
3652
                payment_hash=htlc.payment_hash,
3653
                amount_msat=next_amount_msat_htlc,
3654
                cltv_abs=next_cltv_abs,
3655
                onion=processed_onion.next_packet,
3656
            )
3657
        except BaseException as e:
×
3658
            log_fail_reason(f"error sending message to next_peer={next_chan.node_id.hex()}")
×
3659
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
×
3660

3661
        htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id)
1✔
3662
        return htlc_key
1✔
3663

3664
    async def _maybe_forward_trampoline(
            self, *,
            payment_hash: bytes,
            closest_inc_cltv_abs: int,
            total_msat: int,  # total_msat of the outer onion
            any_trampoline_onion: ProcessedOnionPacket,  # any trampoline onion of the incoming htlc set, they should be similar
            fw_payment_key: str,
    ) -> None:
        """Forward a trampoline payment to the next trampoline node.

        Decodes the inner (trampoline) onion payload, derives the fee/cltv budget
        the sender allotted to us, and pays the next trampoline node via
        ``pay_to_node`` (or opens a channel just-in-time for zeroconf peers).

        Raises:
            OnionRoutingFailure: if forwarding is disabled, the onion payload is
                malformed, the RHASH collides with a payreq we created, the
                fee/cltv budget is insufficient, or the outgoing payment fails.
        """
        forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
        forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
        if not (forwarding_enabled and forwarding_trampoline_enabled):
            self.logger.info("trampoline forwarding is disabled. failing htlc.")
            raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
        payload = any_trampoline_onion.hop_data.payload
        payment_data = payload.get('payment_data')
        try:
            payment_secret = payment_data['payment_secret'] if payment_data else os.urandom(32)
            outgoing_node_id = payload["outgoing_node_id"]["outgoing_node_id"]
            amt_to_forward = payload["amt_to_forward"]["amt_to_forward"]
            out_cltv_abs = payload["outgoing_cltv_value"]["outgoing_cltv_value"]
            if "invoice_features" in payload:
                # "legacy" trampoline: the payload carries the receiver's invoice
                # details instead of a further trampoline onion.
                self.logger.info('forward_trampoline: legacy')
                next_trampoline_onion = None
                invoice_features = payload["invoice_features"]["invoice_features"]
                invoice_routing_info = payload["invoice_routing_info"]["invoice_routing_info"]
                r_tags = decode_routing_info(invoice_routing_info)
                self.logger.info(f'r_tags {r_tags}')
                # TODO legacy mpp payment, use total_msat from trampoline onion
            else:
                # end-to-end trampoline: peel one layer and forward the inner onion
                self.logger.info('forward_trampoline: end-to-end')
                invoice_features = LnFeatures.BASIC_MPP_OPT
                next_trampoline_onion = any_trampoline_onion.next_packet
                r_tags = []
        except Exception:
            self.logger.exception('')
            raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')

        if self._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created(payment_hash):
            self.logger.debug(
                f"maybe_forward_trampoline. will FAIL HTLC(s). "
                f"RHASH corresponds to payreq we created. {payment_hash.hex()=}")
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')

        # these are the fee/cltv paid by the sender
        # pay_to_node will raise if they are not sufficient
        budget = PaymentFeeBudget(
            fee_msat=total_msat - amt_to_forward,
            cltv=closest_inc_cltv_abs - out_cltv_abs,
        )
        self.logger.info(f'trampoline forwarding. budget={budget}')
        self.logger.info(f'trampoline forwarding. {closest_inc_cltv_abs=}, {out_cltv_abs=}')
        # To convert abs vs rel cltvs, we need to guess blockheight used by original sender as "current blockheight".
        # Blocks might have been mined since.
        # - if we skew towards the past, we decrease our own cltv_budget accordingly (which is ok)
        # - if we skew towards the future, we decrease the cltv_budget for the subsequent nodes in the path,
        #   which can result in them failing the payment.
        # So we skew towards the past and guess that there has been 1 new block mined since the payment began:
        local_height_of_onion_creator = self.network.get_local_height() - 1
        cltv_budget_for_rest_of_route = out_cltv_abs - local_height_of_onion_creator

        # refuse to forward for less than 1000 msat of fees / less than 576 blocks of cltv budget
        if budget.fee_msat < 1000:
            raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT, data=b'')
        if budget.cltv < 576:
            raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')

        # do we have a connection to the node?
        next_peer = self.peers.get(outgoing_node_id)
        if next_peer and next_peer.accepts_zeroconf():
            self.logger.info('JIT: found next_peer')
            for next_chan in next_peer.channels.values():
                if next_chan.can_pay(amt_to_forward):
                    # todo: detect if we can do mpp
                    self.logger.info('jit: next_chan can pay')
                    break
            else:
                # no existing channel can carry the payment: open one just-in-time
                scid_alias = self._scid_alias_of_node(next_peer.pubkey)
                route = [RouteEdge(
                    start_node=next_peer.pubkey,
                    end_node=outgoing_node_id,
                    short_channel_id=scid_alias,
                    fee_base_msat=0,
                    fee_proportional_millionths=0,
                    cltv_delta=144,
                    node_features=0
                )]
                next_onion, amount_msat, cltv_abs, session_key = self.create_onion_for_route(
                    route=route,
                    amount_msat=amt_to_forward,
                    total_msat=amt_to_forward,
                    payment_hash=payment_hash,
                    min_final_cltv_delta=cltv_budget_for_rest_of_route,
                    payment_secret=payment_secret,
                    trampoline_onion=next_trampoline_onion,
                )
                await self.open_channel_just_in_time(
                    next_peer=next_peer,
                    next_amount_msat_htlc=amt_to_forward,
                    next_cltv_abs=cltv_abs,
                    payment_hash=payment_hash,
                    next_onion=next_onion)
                return

        try:
            await self.pay_to_node(
                node_pubkey=outgoing_node_id,
                payment_hash=payment_hash,
                payment_secret=payment_secret,
                amount_to_pay=amt_to_forward,
                min_final_cltv_delta=cltv_budget_for_rest_of_route,
                r_tags=r_tags,
                invoice_features=invoice_features,
                fwd_trampoline_onion=next_trampoline_onion,
                budget=budget,
                attempts=100,
                fw_payment_key=fw_payment_key,
            )
        except OnionRoutingFailure:
            raise
        except FeeBudgetExceeded:
            raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT, data=b'')
        except PaymentFailure as e:
            self.logger.debug(
                f"maybe_forward_trampoline. PaymentFailure for {payment_hash.hex()=}, {payment_secret.hex()=}: {e!r}")
            raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
    def _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created(self, payment_hash: bytes) -> bool:
1✔
3791
        """Returns True if the HTLC should be failed.
3792
        We must not forward HTLCs with a matching payment_hash to a payment request we created.
3793
        Example attack:
3794
        - Bob creates payment request with HASH1, for 1 BTC; and gives the payreq to Alice
3795
        - Alice sends htlc A->B->C, for 1 sat, with HASH1
3796
        - Bob must not release the preimage of HASH1
3797
        """
3798
        payment_info = self.get_payment_info(payment_hash)
1✔
3799
        is_our_payreq = payment_info and payment_info.direction == RECEIVED
1✔
3800
        # note: If we don't have the preimage for a payment request, then it must be a hold invoice.
3801
        #       Hold invoices are created by other parties (e.g. a counterparty initiating a submarine swap),
3802
        #       and it is the other party choosing the payment_hash. If we failed HTLCs with payment_hashes colliding
3803
        #       with hold invoices, then a party that can make us save a hold invoice for an arbitrary hash could
3804
        #       also make us fail arbitrary HTLCs.
3805
        return bool(is_our_payreq and self.get_preimage(payment_hash))
1✔
3806

3807
    def create_onion_for_route(
        self, *,
        route: 'LNPaymentRoute',
        amount_msat: int,
        total_msat: int,
        payment_hash: bytes,
        min_final_cltv_delta: int,
        payment_secret: bytes,
        trampoline_onion: Optional[OnionPacket] = None,
    ):
        """Build the onion packet for `route` and compute the first-hop HTLC params.

        Returns a 4-tuple ``(onion, amount_msat, cltv_abs, session_key)``, where
        ``amount_msat`` and ``cltv_abs`` are the values for the first hop (i.e.
        after accumulating the fees/cltv-deltas of the later hops).

        Note: mutates ``route`` in place (ORs our own features into the first edge).

        Raises:
            PaymentFailure: if the resulting first-hop CLTV is too far in the future.
        """
        # add features learned during "init" for direct neighbour:
        route[0].node_features |= self.features
        local_height = self.network.get_local_height()
        final_cltv_abs = local_height + min_final_cltv_delta
        # per-hop payloads plus the accumulated amount/cltv for the first hop:
        hops_data, amount_msat, cltv_abs = calc_hops_data_for_payment(
            route,
            amount_msat,
            final_cltv_abs=final_cltv_abs,
            total_msat=total_msat,
            payment_secret=payment_secret)
        self.logger.info(f"pay len(route)={len(route)}. for payment_hash={payment_hash.hex()}")
        for i in range(len(route)):
            self.logger.info(f"  {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}")
        assert final_cltv_abs <= cltv_abs, (final_cltv_abs, cltv_abs)
        session_key = os.urandom(32) # session_key
        # if we are forwarding a trampoline payment, add trampoline onion
        if trampoline_onion:
            self.logger.info(f'adding trampoline onion to final payload')
            # embed the (inner) trampoline onion into the last hop's payload of
            # the (outer) onion:
            trampoline_payload = dict(hops_data[-1].payload)
            trampoline_payload["trampoline_onion_packet"] = {
                "version": trampoline_onion.version,
                "public_key": trampoline_onion.public_key,
                "hops_data": trampoline_onion.hops_data,
                "hmac": trampoline_onion.hmac
            }
            hops_data[-1] = dataclasses.replace(hops_data[-1], payload=trampoline_payload)
            # debug-only logging of the trampoline route, if available:
            if t_hops_data := trampoline_onion._debug_hops_data:  # None if trampoline-forwarding
                t_route = trampoline_onion._debug_route
                assert t_route is not None
                self.logger.info(f"lnpeer.pay len(t_route)={len(t_route)}")
                for i in range(len(t_route)):
                    self.logger.info(f"  {i}: t_node={t_route[i].end_node.hex()} hop_data={t_hops_data[i]!r}")
        # create onion packet
        payment_path_pubkeys = [x.node_id for x in route]
        onion = new_onion_packet(payment_path_pubkeys, session_key, hops_data, associated_data=payment_hash) # must use another sessionkey
        self.logger.info(f"starting payment. len(route)={len(hops_data)}.")
        # create htlc
        if cltv_abs > local_height + lnutil.NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
            raise PaymentFailure(f"htlc expiry too far into future. (in {cltv_abs-local_height} blocks)")
        return onion, amount_msat, cltv_abs, session_key

3858
    def save_forwarding_failure(
            self,
            payment_key: str,
            *,
            error_bytes: Optional[bytes] = None,
            failure_message: Optional['OnionRoutingFailure'] = None
    ) -> None:
        """Record a forwarding failure for `payment_key`, hex-encoded for storage."""
        if error_bytes:
            encoded_error = error_bytes.hex()
        else:
            encoded_error = None
        encoded_failure = None
        if failure_message:
            encoded_failure = failure_message.to_bytes().hex()
        self.forwarding_failures[payment_key] = (encoded_error, encoded_failure)
3869
    def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
        """Look up a previously stored forwarding failure, decoding it from hex.

        Missing entries (or missing components) come back as None.
        """
        stored = self.forwarding_failures.get(payment_key, (None, None))
        error_hex, failure_hex = stored
        if error_hex:
            error_bytes = bytes.fromhex(error_hex)
        else:
            error_bytes = None
        failure_message = None
        if failure_hex:
            failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex))
        return error_bytes, failure_message
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc