spesmilo / electrum, build 5558179037184000 (push, via CirrusCI)

18 Feb 2025 02:59PM UTC coverage: 60.523% (+0.02%) from 60.502%

Commit by ecdsa:
simplify history-related commands:
 - reduce number of methods
 - use namedtuples instead of dicts
 - only two types: OnchainHistoryItem and LightningHistoryItem (see the sketch below)
 - channel open/closes are grouped
 - move capital gains into separate RPC
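
For illustration, a minimal sketch of what the namedtuple mentioned above might look like. The field set is inferred from how LightningHistoryItem is constructed in get_lightning_history() further down in this file; the actual definition lives in electrum/util.py (imported below via "from .util import LightningHistoryItem") and may differ in field names and types.

from typing import NamedTuple, Optional

class LightningHistoryItem(NamedTuple):
    # Sketch only: fields inferred from the LightningHistoryItem(...) call sites
    # in get_lightning_history() below, not from the authoritative util.py definition.
    type: str                    # 'payment', 'channel_opening' or 'channel_closing'
    payment_hash: Optional[str]  # hex RHASH for payments, None for channel events
    preimage: Optional[str]      # hex preimage, if known
    amount_msat: int             # signed amount; for sent payments the fee is included
    fee_msat: Optional[int]
    group_id: Optional[str]      # e.g. funding/closing txid or swap group id, used to group rows
    direction: int               # PaymentDirection value (SENT, RECEIVED, SELF_PAYMENT, FORWARDING)
    timestamp: int
    label: str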

34 of 102 new or added lines in 5 files covered. (33.33%)

4 existing lines in 2 files now uncovered.

20266 of 33485 relevant lines covered (60.52%)

3.02 hits per line

Source file: /electrum/lnworker.py (51.61% of lines covered)

1
# Copyright (C) 2018 The Electrum developers
2
# Distributed under the MIT software license, see the accompanying
3
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4

5
import asyncio
5✔
6
import os
5✔
7
from decimal import Decimal
5✔
8
import random
5✔
9
import time
5✔
10
import operator
5✔
11
import enum
5✔
12
from enum import IntEnum, Enum
5✔
13
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
5✔
14
                    NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable, Awaitable)
15
import threading
5✔
16
import socket
5✔
17
import json
5✔
18
from datetime import datetime, timezone
5✔
19
from functools import partial, cached_property
5✔
20
from collections import defaultdict
5✔
21
import concurrent
5✔
22
from concurrent import futures
5✔
23
import urllib.parse
5✔
24
import itertools
5✔
25

26
import aiohttp
5✔
27
import dns.resolver
5✔
28
import dns.exception
5✔
29
from aiorpcx import run_in_thread, NetAddress, ignore_after
5✔
30
from electrum_ecc import ecdsa_der_sig_from_ecdsa_sig64
5✔
31

32
from . import constants, util
5✔
33
from . import keystore
5✔
34
from .util import profiler, chunks, OldTaskGroup, ESocksProxy
5✔
35
from .invoices import Invoice, PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, LN_EXPIRY_NEVER
5✔
36
from .invoices import BaseInvoice
5✔
37
from .util import NetworkRetryManager, JsonRPCClient, NotEnoughFunds
5✔
38
from .util import EventListener, event_listener
5✔
39
from .keystore import BIP32_KeyStore
5✔
40
from .bitcoin import COIN
5✔
41
from .bitcoin import opcodes, make_op_return, address_to_scripthash
5✔
42
from .transaction import Transaction
5✔
43
from .transaction import get_script_type_from_output_script
5✔
44
from .crypto import sha256
5✔
45
from .bip32 import BIP32Node
5✔
46
from .util import bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
5✔
47
from .crypto import chacha20_encrypt, chacha20_decrypt
5✔
48
from .util import ignore_exceptions, make_aiohttp_session
5✔
49
from .util import timestamp_to_datetime, random_shuffled_copy
5✔
50
from .util import MyEncoder, is_private_netaddress, UnrelatedTransactionException
5✔
51
from .util import LightningHistoryItem
5✔
52
from .logging import Logger
5✔
53
from .lntransport import LNTransport, LNResponderTransport, LNTransportBase, LNPeerAddr, split_host_port, extract_nodeid, ConnStringFormatError
5✔
54
from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
5✔
55
from .lnaddr import lnencode, LnAddr, lndecode
5✔
56
from .lnchannel import Channel, AbstractChannel
5✔
57
from .lnchannel import ChannelState, PeerState, HTLCWithStatus
5✔
58
from .lnrater import LNRater
5✔
59
from . import lnutil
5✔
60
from .lnutil import funding_output_script
5✔
61
from .lnutil import serialize_htlc_key, deserialize_htlc_key
5✔
62
from .bitcoin import DummyAddress
5✔
63
from .lnutil import (Outpoint,
5✔
64
                     get_compressed_pubkey_from_bech32,
65
                     PaymentFailure,
66
                     generate_keypair, LnKeyFamily, LOCAL, REMOTE,
67
                     MIN_FINAL_CLTV_DELTA_FOR_INVOICE,
68
                     NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
69
                     UpdateAddHtlc, Direction, LnFeatures, ShortChannelID,
70
                     HtlcLog, derive_payment_secret_from_payment_preimage,
71
                     NoPathFound, InvalidGossipMsg, FeeBudgetExceeded)
72
from .lnutil import ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget
5✔
73
from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
5✔
74
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
5✔
75
from .lnmsg import decode_msg
5✔
76
from .i18n import _
5✔
77
from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_within_budget,
5✔
78
                       NoChannelPolicy, LNPathInconsistent)
79
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_TIMESTAMP_INF
5✔
80
from . import lnsweep
5✔
81
from .lnwatcher import LNWalletWatcher
5✔
82
from .crypto import pw_encode_with_version_and_mac, pw_decode_with_version_and_mac
5✔
83
from .lnutil import ImportedChannelBackupStorage, OnchainChannelBackupStorage
5✔
84
from .lnchannel import ChannelBackup
5✔
85
from .channel_db import UpdateStatus, ChannelDBNotLoaded
5✔
86
from .channel_db import get_mychannel_info, get_mychannel_policy
5✔
87
from .submarine_swaps import SwapManager
5✔
88
from .channel_db import ChannelInfo, Policy
5✔
89
from .mpp_split import suggest_splits, SplitConfigRating
5✔
90
from .trampoline import create_trampoline_route_and_onion, is_legacy_relay
5✔
91
from .json_db import stored_in
5✔
92

93
if TYPE_CHECKING:
5✔
94
    from .network import Network
×
95
    from .wallet import Abstract_Wallet
×
96
    from .channel_db import ChannelDB
×
97
    from .simple_config import SimpleConfig
×
98

99

100
SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted
5✔
101

102
NUM_PEERS_TARGET = 4
5✔
103

104
# onchain channel backup data
105
CB_VERSION = 0
5✔
106
CB_MAGIC_BYTES = bytes([0, 0, 0, CB_VERSION])
5✔
107
NODE_ID_PREFIX_LEN = 16
5✔
108

109

110
from .trampoline import trampolines_by_id, hardcoded_trampoline_nodes, is_hardcoded_trampoline
5✔
111

112

113
class PaymentDirection(IntEnum):
5✔
114
    SENT = 0
5✔
115
    RECEIVED = 1
5✔
116
    SELF_PAYMENT = 2
5✔
117
    FORWARDING = 3
5✔
118

119

120
class PaymentInfo(NamedTuple):
5✔
121
    payment_hash: bytes
5✔
122
    amount_msat: Optional[int]
5✔
123
    direction: int
5✔
124
    status: int
5✔
125

126

127
# Note: these states are persisted in the wallet file.
128
# Do not modify them without performing a wallet db upgrade
129
class RecvMPPResolution(IntEnum):
5✔
130
    WAITING = 0
5✔
131
    EXPIRED = 1
5✔
132
    ACCEPTED = 2
5✔
133
    FAILED = 3
5✔
134

135

136
class ReceivedMPPStatus(NamedTuple):
5✔
137
    resolution: RecvMPPResolution
5✔
138
    expected_msat: int
5✔
139
    htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
5✔
140

141
    @stored_in('received_mpp_htlcs', tuple)
5✔
142
    def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
5✔
143
        htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list])
×
144
        return ReceivedMPPStatus(
×
145
            resolution=RecvMPPResolution(resolution),
146
            expected_msat=expected_msat,
147
            htlc_set=htlc_set)
148

149
SentHtlcKey = Tuple[bytes, ShortChannelID, int]  # RHASH, scid, htlc_id
5✔
150

151

152
class SentHtlcInfo(NamedTuple):
5✔
153
    route: LNPaymentRoute
5✔
154
    payment_secret_orig: bytes
5✔
155
    payment_secret_bucket: bytes
5✔
156
    amount_msat: int
5✔
157
    bucket_msat: int
5✔
158
    amount_receiver_msat: int
5✔
159
    trampoline_fee_level: Optional[int]
5✔
160
    trampoline_route: Optional[LNPaymentRoute]
5✔
161

162

163
class ErrorAddingPeer(Exception): pass
5✔
164

165

166
# set some feature flags as baseline for both LNWallet and LNGossip
167
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
168
BASE_FEATURES = (
5✔
169
    LnFeatures(0)
170
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
171
    | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
172
    | LnFeatures.VAR_ONION_OPT
173
    | LnFeatures.PAYMENT_SECRET_OPT
174
    | LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
175
)
176

177
# we do not want to receive unrequested gossip (see lnpeer.maybe_save_remote_update)
178
LNWALLET_FEATURES = (
5✔
179
    BASE_FEATURES
180
    | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
181
    | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
182
    | LnFeatures.GOSSIP_QUERIES_REQ
183
    | LnFeatures.VAR_ONION_REQ
184
    | LnFeatures.PAYMENT_SECRET_REQ
185
    | LnFeatures.BASIC_MPP_OPT
186
    | LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
187
    | LnFeatures.OPTION_SHUTDOWN_ANYSEGWIT_OPT
188
    | LnFeatures.OPTION_CHANNEL_TYPE_OPT
189
    | LnFeatures.OPTION_SCID_ALIAS_OPT
190
    | LnFeatures.OPTION_SUPPORT_LARGE_CHANNEL_OPT
191
)
192

193
LNGOSSIP_FEATURES = (
5✔
194
    BASE_FEATURES
195
    | LnFeatures.GOSSIP_QUERIES_OPT
196
    | LnFeatures.GOSSIP_QUERIES_REQ
197
)
198

199

200
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
5✔
201

202
    def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
5✔
203
        Logger.__init__(self)
5✔
204
        NetworkRetryManager.__init__(
5✔
205
            self,
206
            max_retry_delay_normal=3600,
207
            init_retry_delay_normal=600,
208
            max_retry_delay_urgent=300,
209
            init_retry_delay_urgent=4,
210
        )
211
        self.lock = threading.RLock()
5✔
212
        self.node_keypair = node_keypair
5✔
213
        self._peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer  # needs self.lock
5✔
214
        self.taskgroup = OldTaskGroup()
5✔
215
        self.listen_server = None  # type: Optional[asyncio.AbstractServer]
5✔
216
        self.features = features
5✔
217
        self.network = None  # type: Optional[Network]
5✔
218
        self.config = config
5✔
219
        self.stopping_soon = False  # whether we are being shut down
5✔
220
        self.register_callbacks()
5✔
221

222
    @property
5✔
223
    def channel_db(self):
5✔
224
        return self.network.channel_db if self.network else None
×
225

226
    def uses_trampoline(self):
5✔
227
        return not bool(self.channel_db)
×
228

229
    @property
5✔
230
    def peers(self) -> Mapping[bytes, Peer]:
5✔
231
        """Returns a read-only copy of peers."""
232
        with self.lock:
×
233
            return self._peers.copy()
×
234

235
    def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
5✔
236
        return {}
×
237

238
    def get_node_alias(self, node_id: bytes) -> Optional[str]:
5✔
239
        """Returns the alias of the node, or None if unknown."""
240
        node_alias = None
×
241
        if not self.uses_trampoline():
×
242
            node_info = self.channel_db.get_node_info_for_node_id(node_id)
×
243
            if node_info:
×
244
                node_alias = node_info.alias
×
245
        else:
246
            for k, v in hardcoded_trampoline_nodes().items():
×
247
                if v.pubkey.startswith(node_id):
×
248
                    node_alias = k
×
249
                    break
×
250
        return node_alias
×
251

252
    async def maybe_listen(self):
5✔
253
        # FIXME: only one LNWorker can listen at a time (single port)
254
        listen_addr = self.config.LIGHTNING_LISTEN
×
255
        if listen_addr:
×
256
            self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
×
257
            try:
×
258
                netaddr = NetAddress.from_string(listen_addr)
×
259
            except Exception as e:
×
260
                self.logger.error(f"failed to parse config key '{self.config.cv.LIGHTNING_LISTEN.key()}'. got: {e!r}")
×
261
                return
×
262
            addr = str(netaddr.host)
×
263
            async def cb(reader, writer):
×
264
                transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
×
265
                try:
×
266
                    node_id = await transport.handshake()
×
267
                except Exception as e:
×
268
                    self.logger.info(f'handshake failure from incoming connection: {e!r}')
×
269
                    return
×
270
                await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
271
            try:
×
272
                self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
×
273
            except OSError as e:
×
274
                self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")
×
275

276
    async def main_loop(self):
5✔
277
        self.logger.info("starting taskgroup.")
×
278
        try:
×
279
            async with self.taskgroup as group:
×
280
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
×
281
        except Exception as e:
×
282
            self.logger.exception("taskgroup died.")
×
283
        finally:
284
            self.logger.info("taskgroup stopped.")
×
285

286
    async def _maintain_connectivity(self):
5✔
287
        while True:
×
288
            await asyncio.sleep(1)
×
289
            if self.stopping_soon:
×
290
                return
×
291
            now = time.time()
×
292
            if len(self._peers) >= NUM_PEERS_TARGET:
×
293
                continue
×
294
            peers = await self._get_next_peers_to_try()
×
295
            for peer in peers:
×
296
                if self._can_retry_addr(peer, now=now):
×
297
                    try:
×
298
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
×
299
                    except ErrorAddingPeer as e:
×
300
                        self.logger.info(f"failed to add peer: {peer}. exc: {e!r}")
×
301

302
    async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
5✔
303
        if node_id in self._peers:
×
304
            return self._peers[node_id]
×
305
        port = int(port)
×
306
        peer_addr = LNPeerAddr(host, port, node_id)
×
307
        self._trying_addr_now(peer_addr)
×
308
        self.logger.info(f"adding peer {peer_addr}")
×
309
        if node_id == self.node_keypair.pubkey:
×
310
            raise ErrorAddingPeer("cannot connect to self")
×
311
        transport = LNTransport(self.node_keypair.privkey, peer_addr,
×
312
                                e_proxy=ESocksProxy.from_network_settings(self.network))
313
        peer = await self._add_peer_from_transport(node_id=node_id, transport=transport)
×
314
        assert peer
×
315
        return peer
×
316

317
    async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransportBase) -> Optional[Peer]:
5✔
318
        with self.lock:
×
319
            existing_peer = self._peers.get(node_id)
×
320
            if existing_peer:
×
321
                # Two instances of the same wallet are attempting to connect simultaneously.
322
                # If we let the new connection replace the existing one, the two instances might
323
                # both keep trying to reconnect, resulting in neither being usable.
324
                if existing_peer.is_initialized():
×
325
                    # give priority to the existing connection
326
                    return
×
327
                else:
328
                    # Use the new connection. (e.g. old peer might be an outgoing connection
329
                    # for an outdated host/port that will never connect)
330
                    existing_peer.close_and_cleanup()
×
331
            peer = Peer(self, node_id, transport)
×
332
            assert node_id not in self._peers
×
333
            self._peers[node_id] = peer
×
334
        await self.taskgroup.spawn(peer.main_loop())
×
335
        return peer
×
336

337
    def peer_closed(self, peer: Peer) -> None:
5✔
338
        with self.lock:
×
339
            peer2 = self._peers.get(peer.pubkey)
×
340
            if peer2 is peer:
×
341
                self._peers.pop(peer.pubkey)
×
342

343
    def num_peers(self) -> int:
5✔
344
        return sum([p.is_initialized() for p in self.peers.values()])
×
345

346
    def start_network(self, network: 'Network'):
5✔
347
        assert network
×
348
        assert self.network is None, "already started"
×
349
        self.network = network
×
350
        self._add_peers_from_config()
×
351
        asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
×
352

353
    async def stop(self):
5✔
354
        if self.listen_server:
5✔
355
            self.listen_server.close()
×
356
        self.unregister_callbacks()
5✔
357
        await self.taskgroup.cancel_remaining()
5✔
358

359
    def _add_peers_from_config(self):
5✔
360
        peer_list = self.config.LIGHTNING_PEERS or []
×
361
        for host, port, pubkey in peer_list:
×
362
            asyncio.run_coroutine_threadsafe(
×
363
                self._add_peer(host, int(port), bfh(pubkey)),
364
                self.network.asyncio_loop)
365

366
    def is_good_peer(self, peer: LNPeerAddr) -> bool:
5✔
367
        # the purpose of this method is to filter peers that advertise the desired feature bits
368
        # it is disabled for now, because feature bits published in node announcements seem to be unreliable
369
        return True
×
370
        node_id = peer.pubkey
371
        node = self.channel_db._nodes.get(node_id)
372
        if not node:
373
            return False
374
        try:
375
            ln_compare_features(self.features, node.features)
376
        except IncompatibleLightningFeatures:
377
            return False
378
        #self.logger.info(f'is_good {peer.host}')
379
        return True
380

381
    def on_peer_successfully_established(self, peer: Peer) -> None:
5✔
382
        if isinstance(peer.transport, LNTransport):
5✔
383
            peer_addr = peer.transport.peer_addr
×
384
            # reset connection attempt count
385
            self._on_connection_successfully_established(peer_addr)
×
386
            if not self.uses_trampoline():
×
387
                # add into channel db
388
                self.channel_db.add_recent_peer(peer_addr)
×
389
            # save network address into channels we might have with peer
390
            for chan in peer.channels.values():
×
391
                chan.add_or_update_peer_addr(peer_addr)
×
392

393
    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
5✔
394
        now = time.time()
×
395
        await self.channel_db.data_loaded.wait()
×
396
        # first try from recent peers
397
        recent_peers = self.channel_db.get_recent_peers()
×
398
        for peer in recent_peers:
×
399
            if not peer:
×
400
                continue
×
401
            if peer.pubkey in self._peers:
×
402
                continue
×
403
            if not self._can_retry_addr(peer, now=now):
×
404
                continue
×
405
            if not self.is_good_peer(peer):
×
406
                continue
×
407
            return [peer]
×
408
        # try random peer from graph
409
        unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
×
410
        if unconnected_nodes:
×
411
            for node_id in unconnected_nodes:
×
412
                addrs = self.channel_db.get_node_addresses(node_id)
×
413
                if not addrs:
×
414
                    continue
×
415
                host, port, timestamp = self.choose_preferred_address(list(addrs))
×
416
                try:
×
417
                    peer = LNPeerAddr(host, port, node_id)
×
418
                except ValueError:
×
419
                    continue
×
420
                if not self._can_retry_addr(peer, now=now):
×
421
                    continue
×
422
                if not self.is_good_peer(peer):
×
423
                    continue
×
424
                #self.logger.info('taking random ln peer from our channel db')
425
                return [peer]
×
426

427
        # getting desperate... let's try hardcoded fallback list of peers
428
        fallback_list = constants.net.FALLBACK_LN_NODES
×
429
        fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
×
430
        if fallback_list:
×
431
            return [random.choice(fallback_list)]
×
432

433
        # last resort: try dns seeds (BOLT-10)
434
        return await run_in_thread(self._get_peers_from_dns_seeds)
×
435

436
    def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
5✔
437
        # NOTE: potentially long blocking call, do not run directly on asyncio event loop.
438
        # Return several peers to reduce the number of dns queries.
439
        if not constants.net.LN_DNS_SEEDS:
×
440
            return []
×
441
        dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
×
442
        self.logger.info('asking dns seed "{}" for ln peers'.format(dns_seed))
×
443
        try:
×
444
            # note: this might block for several seconds
445
            # this will include bech32-encoded-pubkeys and ports
446
            srv_answers = resolve_dns_srv('r{}.{}'.format(
×
447
                constants.net.LN_REALM_BYTE, dns_seed))
448
        except dns.exception.DNSException as e:
×
449
            self.logger.info(f'failed querying (1) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
450
            return []
×
451
        random.shuffle(srv_answers)
×
452
        num_peers = 2 * NUM_PEERS_TARGET
×
453
        srv_answers = srv_answers[:num_peers]
×
454
        # we now have pubkeys and ports but host is still needed
455
        peers = []
×
456
        for srv_ans in srv_answers:
×
457
            try:
×
458
                # note: this might block for several seconds
459
                answers = dns.resolver.resolve(srv_ans['host'])
×
460
            except dns.exception.DNSException as e:
×
461
                self.logger.info(f'failed querying (2) dns seed "{dns_seed}" for ln peers: {repr(e)}')
×
462
                continue
×
463
            try:
×
464
                ln_host = str(answers[0])
×
465
                port = int(srv_ans['port'])
×
466
                bech32_pubkey = srv_ans['host'].split('.')[0]
×
467
                pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
×
468
                peers.append(LNPeerAddr(ln_host, port, pubkey))
×
469
            except Exception as e:
×
470
                self.logger.info(f'error with parsing peer from dns seed: {repr(e)}')
×
471
                continue
×
472
        self.logger.info(f'got {len(peers)} ln peers from dns seed')
×
473
        return peers
×
474

475
    @staticmethod
5✔
476
    def choose_preferred_address(addr_list: Sequence[Tuple[str, int, int]]) -> Tuple[str, int, int]:
5✔
477
        assert len(addr_list) >= 1
×
478
        # choose the most recent one that is an IP
479
        for host, port, timestamp in sorted(addr_list, key=lambda a: -a[2]):
×
480
            if is_ip_address(host):
×
481
                return host, port, timestamp
×
482
        # otherwise choose one at random
483
        # TODO maybe filter out onion if not on tor?
484
        choice = random.choice(addr_list)
×
485
        return choice
×
486

487
    @event_listener
5✔
488
    def on_event_proxy_set(self, *args):
5✔
489
        for peer in self.peers.values():
×
490
            peer.close_and_cleanup()
×
491
        self._clear_addr_retry_times()
×
492

493
    @log_exceptions
5✔
494
    async def add_peer(self, connect_str: str) -> Peer:
5✔
495
        node_id, rest = extract_nodeid(connect_str)
×
496
        peer = self._peers.get(node_id)
×
497
        if not peer:
×
498
            if rest is not None:
×
499
                host, port = split_host_port(rest)
×
500
            else:
501
                if self.uses_trampoline():
×
502
                    addr = trampolines_by_id().get(node_id)
×
503
                    if not addr:
×
504
                        raise ConnStringFormatError(_('Address unknown for node:') + ' ' + node_id.hex())
×
505
                    host, port = addr.host, addr.port
×
506
                else:
507
                    addrs = self.channel_db.get_node_addresses(node_id)
×
508
                    if not addrs:
×
509
                        raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + node_id.hex())
×
510
                    host, port, timestamp = self.choose_preferred_address(list(addrs))
×
511
            port = int(port)
×
512

513
            if not self.network.proxy:
×
514
                # Try DNS-resolving the host (if needed). This is simply so that
515
                # the caller gets a nice exception if it cannot be resolved.
516
                # (we don't do the DNS lookup if a proxy is set, to avoid a DNS-leak)
517
                if host.endswith('.onion'):
×
518
                    raise ConnStringFormatError(_('.onion address, but no proxy configured'))
×
519
                try:
×
520
                    await asyncio.get_running_loop().getaddrinfo(host, port)
×
521
                except socket.gaierror:
×
522
                    raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
×
523

524
            # add peer
525
            peer = await self._add_peer(host, port, node_id)
×
526
        return peer
×
527

528

529
class LNGossip(LNWorker):
5✔
530
    max_age = 14*24*3600
5✔
531
    LOGGING_SHORTCUT = 'g'
5✔
532

533
    def __init__(self, config: 'SimpleConfig'):
5✔
534
        seed = os.urandom(32)
×
535
        node = BIP32Node.from_rootseed(seed, xtype='standard')
×
536
        xprv = node.to_xprv()
×
537
        node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
×
538
        LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
×
539
        self.unknown_ids = set()
×
540

541
    def start_network(self, network: 'Network'):
5✔
542
        super().start_network(network)
×
543
        for coro in [
×
544
                self._maintain_connectivity(),
545
                self.maintain_db(),
546
        ]:
547
            tg_coro = self.taskgroup.spawn(coro)
×
548
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
549

550
    async def maintain_db(self):
5✔
551
        await self.channel_db.data_loaded.wait()
×
552
        while True:
×
553
            if len(self.unknown_ids) == 0:
×
554
                self.channel_db.prune_old_policies(self.max_age)
×
555
                self.channel_db.prune_orphaned_channels()
×
556
            await asyncio.sleep(120)
×
557

558
    async def add_new_ids(self, ids: Iterable[bytes]):
5✔
559
        known = self.channel_db.get_channel_ids()
×
560
        new = set(ids) - set(known)
×
561
        self.unknown_ids.update(new)
×
562
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
563
        util.trigger_callback('gossip_peers', self.num_peers())
×
564
        util.trigger_callback('ln_gossip_sync_progress')
×
565

566
    def get_ids_to_query(self) -> Sequence[bytes]:
5✔
567
        N = 500
×
568
        l = list(self.unknown_ids)
×
569
        self.unknown_ids = set(l[N:])
×
570
        util.trigger_callback('unknown_channels', len(self.unknown_ids))
×
571
        util.trigger_callback('ln_gossip_sync_progress')
×
572
        return l[0:N]
×
573

574
    def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int], Optional[int]]:
5✔
575
        """Estimates the gossip synchronization process and returns the number
576
        of synchronized channels, the total channels in the network and a
577
        rescaled percentage of the synchronization process."""
578
        if self.num_peers() == 0:
×
579
            return None, None, None
×
580
        nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
×
581
        num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
×
582
        # some channels will never have two policies (only one is in gossip?...)
583
        # so if we have at least 1 policy for a channel, we consider that channel "complete" here
584
        current_est = num_db_channels - nchans_with_0p
×
585
        total_est = len(self.unknown_ids) + num_db_channels
×
586

587
        progress = current_est / total_est if total_est and current_est else 0
×
588
        progress_percent = (1.0 / 0.95 * progress) * 100
×
589
        progress_percent = min(progress_percent, 100)
×
590
        progress_percent = round(progress_percent)
×
591
        # take a minimal number of synchronized channels to get a more accurate
592
        # percentage estimate
593
        if current_est < 200:
×
594
            progress_percent = 0
×
595
        return current_est, total_est, progress_percent
×
596

597
    async def process_gossip(self, chan_anns, node_anns, chan_upds):
5✔
598
        # note: we run in the originating peer's TaskGroup, so we can safely raise here
599
        #       and disconnect only from that peer
600
        await self.channel_db.data_loaded.wait()
×
601
        self.logger.debug(f'process_gossip {len(chan_anns)} {len(node_anns)} {len(chan_upds)}')
×
602
        # channel announcements
603
        def process_chan_anns():
×
604
            for payload in chan_anns:
×
605
                self.channel_db.verify_channel_announcement(payload)
×
606
            self.channel_db.add_channel_announcements(chan_anns)
×
607
        await run_in_thread(process_chan_anns)
×
608
        # node announcements
609
        def process_node_anns():
×
610
            for payload in node_anns:
×
611
                self.channel_db.verify_node_announcement(payload)
×
612
            self.channel_db.add_node_announcements(node_anns)
×
613
        await run_in_thread(process_node_anns)
×
614
        # channel updates
615
        categorized_chan_upds = await run_in_thread(partial(
×
616
            self.channel_db.add_channel_updates,
617
            chan_upds,
618
            max_age=self.max_age))
619
        orphaned = categorized_chan_upds.orphaned
×
620
        if orphaned:
×
621
            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
×
622
            orphaned_ids = [c['short_channel_id'] for c in orphaned]
×
623
            await self.add_new_ids(orphaned_ids)
×
624
        if categorized_chan_upds.good:
×
625
            self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
×
626

627

628
class PaySession(Logger):
5✔
629
    def __init__(
5✔
630
            self,
631
            *,
632
            payment_hash: bytes,
633
            payment_secret: bytes,
634
            initial_trampoline_fee_level: int,
635
            invoice_features: int,
636
            r_tags,
637
            min_final_cltv_delta: int,  # delta for last node (typically from invoice)
638
            amount_to_pay: int,  # total payment amount final receiver will get
639
            invoice_pubkey: bytes,
640
            uses_trampoline: bool,  # whether sender uses trampoline or gossip
641
            use_two_trampolines: bool,  # whether legacy payments will try to use two trampolines
642
    ):
643
        assert payment_hash
5✔
644
        assert payment_secret
5✔
645
        self.payment_hash = payment_hash
5✔
646
        self.payment_secret = payment_secret
5✔
647
        self.payment_key = payment_hash + payment_secret
5✔
648
        Logger.__init__(self)
5✔
649

650
        self.invoice_features = LnFeatures(invoice_features)
5✔
651
        self.r_tags = r_tags
5✔
652
        self.min_final_cltv_delta = min_final_cltv_delta
5✔
653
        self.amount_to_pay = amount_to_pay
5✔
654
        self.invoice_pubkey = invoice_pubkey
5✔
655

656
        self.sent_htlcs_q = asyncio.Queue()  # type: asyncio.Queue[HtlcLog]
5✔
657
        self.start_time = time.time()
5✔
658

659
        self.uses_trampoline = uses_trampoline
5✔
660
        self.trampoline_fee_level = initial_trampoline_fee_level
5✔
661
        self.failed_trampoline_routes = []
5✔
662
        self.use_two_trampolines = use_two_trampolines
5✔
663
        self._sent_buckets = dict()  # psecret_bucket -> (amount_sent, amount_failed)
5✔
664

665
        self._amount_inflight = 0  # what we sent in htlcs (that receiver gets, without fees)
5✔
666
        self._nhtlcs_inflight = 0
5✔
667
        self.is_active = True  # is still trying to send new htlcs?
5✔
668

669
    def diagnostic_name(self):
5✔
670
        pkey = sha256(self.payment_key)
5✔
671
        return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
5✔
672

673
    def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
5✔
674
        if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
5✔
675
            self.trampoline_fee_level += 1
5✔
676
            self.failed_trampoline_routes = []
5✔
677
            self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
5✔
678
        else:
679
            self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
5✔
680

681
    def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
5✔
682
        # FIXME The trampoline nodes in the path are chosen randomly.
683
        #       Some of the errors might depend on how we have chosen them.
684
        #       Having more attempts is currently useful in part because of the randomness,
685
        #       instead we should give feedback to create_routes_for_payment.
686
        # Sometimes the trampoline node fails to send a payment and returns
687
        # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
688
        if failure_msg.code in (
5✔
689
                OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
690
                OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
691
                OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
692
            # TODO: parse the node policy here (not returned by eclair yet)
693
            # TODO: erring node is always the first trampoline even if second
694
            #  trampoline demands more fees, we can't influence this
695
            self.maybe_raise_trampoline_fee(htlc_log)
5✔
696
        elif self.use_two_trampolines:
5✔
697
            self.use_two_trampolines = False
×
698
        elif failure_msg.code in (
5✔
699
                OnionFailureCode.UNKNOWN_NEXT_PEER,
700
                OnionFailureCode.TEMPORARY_NODE_FAILURE):
701
            trampoline_route = htlc_log.route
5✔
702
            r = [hop.end_node.hex() for hop in trampoline_route]
5✔
703
            self.logger.info(f'failed trampoline route: {r}')
5✔
704
            if r not in self.failed_trampoline_routes:
5✔
705
                self.failed_trampoline_routes.append(r)
5✔
706
            else:
707
                pass  # maybe the route was reused between different MPP parts
×
708
        else:
709
            raise PaymentFailure(failure_msg.code_name())
5✔
710

711
    async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
5✔
712
        self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
5✔
713
        htlc_log = await self.sent_htlcs_q.get()
5✔
714
        self._amount_inflight -= htlc_log.amount_msat
5✔
715
        self._nhtlcs_inflight -= 1
5✔
716
        if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
5✔
717
            raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
×
718
        return htlc_log
5✔
719

720
    def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
5✔
721
        self._nhtlcs_inflight += 1
5✔
722
        self._amount_inflight += sent_htlc_info.amount_receiver_msat
5✔
723
        if self._amount_inflight > self.amount_to_pay:  # safety belts
5✔
724
            raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
×
725
        shi = sent_htlc_info
5✔
726
        bkey = shi.payment_secret_bucket
5✔
727
        # if we sent MPP to a trampoline, add item to sent_buckets
728
        if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
5✔
729
            if bkey not in self._sent_buckets:
5✔
730
                self._sent_buckets[bkey] = (0, 0)
5✔
731
            amount_sent, amount_failed = self._sent_buckets[bkey]
5✔
732
            amount_sent += shi.amount_receiver_msat
5✔
733
            self._sent_buckets[bkey] = amount_sent, amount_failed
5✔
734

735
    def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
5✔
736
        shi = sent_htlc_info
5✔
737
        # check sent_buckets if we use trampoline
738
        bkey = shi.payment_secret_bucket
5✔
739
        if self.uses_trampoline and bkey in self._sent_buckets:
5✔
740
            amount_sent, amount_failed = self._sent_buckets[bkey]
5✔
741
            amount_failed += shi.amount_receiver_msat
5✔
742
            self._sent_buckets[bkey] = amount_sent, amount_failed
5✔
743
            if amount_sent != amount_failed:
5✔
744
                self.logger.info('bucket still active...')
5✔
745
                return None
5✔
746
            self.logger.info('bucket failed')
5✔
747
            return amount_sent
5✔
748
        # not using trampoline buckets
749
        return shi.amount_receiver_msat
5✔
750

751
    def get_outstanding_amount_to_send(self) -> int:
5✔
752
        return self.amount_to_pay - self._amount_inflight
5✔
753

754
    def can_be_deleted(self) -> bool:
5✔
755
        """Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
756
        if self.is_active:
5✔
757
            return False
5✔
758
        # note: no one is consuming from sent_htlcs_q anymore
759
        nhtlcs_resolved = self.sent_htlcs_q.qsize()
5✔
760
        assert nhtlcs_resolved <= self._nhtlcs_inflight
5✔
761
        return nhtlcs_resolved == self._nhtlcs_inflight
5✔
762

763

764
class LNWallet(LNWorker):
5✔
765

766
    lnwatcher: Optional['LNWalletWatcher']
5✔
767
    MPP_EXPIRY = 120
5✔
768
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds
5✔
769
    PAYMENT_TIMEOUT = 120
5✔
770
    MPP_SPLIT_PART_FRACTION = 0.2
5✔
771
    MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
5✔
772

773
    def __init__(self, wallet: 'Abstract_Wallet', xprv):
5✔
774
        self.wallet = wallet
5✔
775
        self.config = wallet.config
5✔
776
        self.db = wallet.db
5✔
777
        self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
5✔
778
        self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey
5✔
779
        self.static_payment_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_BASE)
5✔
780
        self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
5✔
781
        self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
5✔
782
        Logger.__init__(self)
5✔
783
        features = LNWALLET_FEATURES
5✔
784
        if self.config.ENABLE_ANCHOR_CHANNELS:
5✔
785
            features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
×
786
        if self.config.ACCEPT_ZEROCONF_CHANNELS:
5✔
787
            features |= LnFeatures.OPTION_ZEROCONF_OPT
×
788
        LNWorker.__init__(self, self.node_keypair, features, config=self.config)
5✔
789
        self.lnwatcher = None
5✔
790
        self.lnrater: LNRater = None
5✔
791
        self.payment_info = self.db.get_dict('lightning_payments')     # RHASH -> amount, direction, is_paid
5✔
792
        self.preimages = self.db.get_dict('lightning_preimages')   # RHASH -> preimage
5✔
793
        self._bolt11_cache = {}
5✔
794
        # note: this sweep_address is only used as fallback; as it might result in address-reuse
795
        self.logs = defaultdict(list)  # type: Dict[str, List[HtlcLog]]  # key is RHASH  # (not persisted)
5✔
796
        # used in tests
797
        self.enable_htlc_settle = True
5✔
798
        self.enable_htlc_settle_onchain = True
5✔
799
        self.enable_htlc_forwarding = True
5✔
800

801
        # note: accessing channels (besides simple lookup) needs self.lock!
802
        self._channels = {}  # type: Dict[bytes, Channel]
5✔
803
        channels = self.db.get_dict("channels")
5✔
804
        for channel_id, c in random_shuffled_copy(channels.items()):
5✔
805
            self._channels[bfh(channel_id)] = chan = Channel(c, lnworker=self)
5✔
806
            self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
5✔
807

808
        self._channel_backups = {}  # type: Dict[bytes, ChannelBackup]
5✔
809
        # order is important: imported should overwrite onchain
810
        for name in ["onchain_channel_backups", "imported_channel_backups"]:
5✔
811
            channel_backups = self.db.get_dict(name)
5✔
812
            for channel_id, storage in channel_backups.items():
5✔
813
                self._channel_backups[bfh(channel_id)] = cb = ChannelBackup(storage, lnworker=self)
×
814
                self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
815

816
        self._paysessions = dict()                      # type: Dict[bytes, PaySession]
5✔
817
        self.sent_htlcs_info = dict()                   # type: Dict[SentHtlcKey, SentHtlcInfo]
5✔
818
        self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs')   # type: Dict[str, ReceivedMPPStatus]  # payment_key -> ReceivedMPPStatus
5✔
819

820
        # detect inflight payments
821
        self.inflight_payments = set()        # (not persisted) keys of invoices that are in PR_INFLIGHT state
5✔
822
        for payment_hash in self.get_payments(status='inflight').keys():
5✔
823
            self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
×
824

825
        # payment forwarding
826
        self.active_forwardings = self.db.get_dict('active_forwardings')    # type: Dict[str, List[str]]        # Dict: payment_key -> list of htlc_keys
5✔
827
        self.forwarding_failures = self.db.get_dict('forwarding_failures')  # type: Dict[str, Tuple[str, str]]  # Dict: payment_key -> (error_bytes, error_message)
5✔
828
        self.downstream_to_upstream_htlc = {}                               # type: Dict[str, str]              # Dict: htlc_key -> htlc_key (not persisted)
5✔
829

830
        # payment_hash -> callback:
831
        self.hold_invoice_callbacks = {}                # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
5✔
832
        self.payment_bundles = []                       # lists of hashes. todo:persist
5✔
833

834
        self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
5✔
835
        self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
5✔
836

837
    def has_deterministic_node_id(self) -> bool:
5✔
838
        return bool(self.db.get('lightning_xprv'))
×
839

840
    def can_have_recoverable_channels(self) -> bool:
5✔
841
        return (self.has_deterministic_node_id()
×
842
                and not self.config.LIGHTNING_LISTEN)
843

844
    def has_recoverable_channels(self) -> bool:
5✔
845
        """Whether *future* channels opened by this wallet would be recoverable
846
        from seed (via putting OP_RETURN outputs into funding txs).
847
        """
848
        return (self.can_have_recoverable_channels()
×
849
                and self.config.LIGHTNING_USE_RECOVERABLE_CHANNELS)
850

851
    @property
5✔
852
    def channels(self) -> Mapping[bytes, Channel]:
5✔
853
        """Returns a read-only copy of channels."""
854
        with self.lock:
5✔
855
            return self._channels.copy()
5✔
856

857
    @property
5✔
858
    def channel_backups(self) -> Mapping[bytes, ChannelBackup]:
5✔
859
        """Returns a read-only copy of channels."""
860
        with self.lock:
×
861
            return self._channel_backups.copy()
×
862

863
    def get_channel_objects(self) -> Mapping[bytes, AbstractChannel]:
5✔
864
        r = self.channel_backups
×
865
        r.update(self.channels)
×
866
        return r
×
867

868
    def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
5✔
869
        return self._channels.get(channel_id, None)
5✔
870

871
    def diagnostic_name(self):
5✔
872
        return self.wallet.diagnostic_name()
5✔
873

874
    @ignore_exceptions
5✔
875
    @log_exceptions
5✔
876
    async def sync_with_remote_watchtower(self):
5✔
877
        self.watchtower_ctns = {}
×
878
        while True:
×
879
            # periodically poll if the user updated 'watchtower_url'
880
            await asyncio.sleep(5)
×
881
            watchtower_url = self.config.WATCHTOWER_CLIENT_URL
×
882
            if not watchtower_url:
×
883
                continue
×
884
            parsed_url = urllib.parse.urlparse(watchtower_url)
×
885
            if not (parsed_url.scheme == 'https' or is_private_netaddress(parsed_url.hostname)):
×
886
                self.logger.warning(f"got watchtower URL for remote tower but we won't use it! "
×
887
                                    f"can only use HTTPS (except if private IP): not using {watchtower_url!r}")
888
                continue
×
889
            # try to sync with the remote watchtower
890
            try:
×
891
                async with make_aiohttp_session(proxy=self.network.proxy) as session:
×
892
                    watchtower = JsonRPCClient(session, watchtower_url)
×
893
                    watchtower.add_method('get_ctn')
×
894
                    watchtower.add_method('add_sweep_tx')
×
895
                    for chan in self.channels.values():
×
896
                        await self.sync_channel_with_watchtower(chan, watchtower)
×
897
            except aiohttp.client_exceptions.ClientConnectorError:
×
898
                self.logger.info(f'could not contact remote watchtower {watchtower_url}')
×
899

900
    def get_watchtower_ctn(self, channel_point):
5✔
901
        return self.watchtower_ctns.get(channel_point)
×
902

903
    async def sync_channel_with_watchtower(self, chan: Channel, watchtower):
5✔
904
        outpoint = chan.funding_outpoint.to_str()
×
905
        addr = chan.get_funding_address()
×
906
        current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
×
907
        watchtower_ctn = await watchtower.get_ctn(outpoint, addr)
×
908
        for ctn in range(watchtower_ctn + 1, current_ctn):
×
909
            sweeptxs = chan.create_sweeptxs_for_watchtower(ctn)
×
910
            for tx in sweeptxs:
×
911
                await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
×
912
            self.watchtower_ctns[outpoint] = ctn
×
913

914
    def start_network(self, network: 'Network'):
5✔
915
        super().start_network(network)
×
916
        self.lnwatcher = LNWalletWatcher(self, network)
×
917
        self.swap_manager.start_network(network)
×
918
        self.lnrater = LNRater(self, network)
×
919

920
        for chan in self.channels.values():
×
921
            if chan.need_to_subscribe():
×
922
                self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
×
923
        for cb in self.channel_backups.values():
×
924
            if cb.need_to_subscribe():
×
925
                self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
926

927
        for coro in [
×
928
                self.maybe_listen(),
929
                self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
930
                self.reestablish_peers_and_channels(),
931
                self.sync_with_remote_watchtower(),
932
        ]:
933
            tg_coro = self.taskgroup.spawn(coro)
×
934
            asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
×
935

936
    async def stop(self):
5✔
937
        self.stopping_soon = True
5✔
938
        if self.listen_server:  # stop accepting new peers
5✔
939
            self.listen_server.close()
×
940
        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
5✔
941
            await self.wait_for_received_pending_htlcs_to_get_removed()
5✔
942
        await LNWorker.stop(self)
5✔
943
        if self.lnwatcher:
5✔
944
            await self.lnwatcher.stop()
×
945
            self.lnwatcher = None
×
946
        if self.swap_manager and self.swap_manager.network:  # may not be present in tests
5✔
947
            await self.swap_manager.stop()
×
948

949
    async def wait_for_received_pending_htlcs_to_get_removed(self):
5✔
950
        assert self.stopping_soon is True
5✔
951
        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
952
        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
953
        #       to wait a bit for it to become irrevocably removed.
954
        # Note: we don't wait for *all htlcs* to get removed, only for those
955
        #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
956
        async with OldTaskGroup() as group:
5✔
957
            for peer in self.peers.values():
5✔
958
                await group.spawn(peer.wait_one_htlc_switch_iteration())
5✔
959
        while True:
5✔
960
            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
5✔
961
                break
5✔
962
            async with OldTaskGroup(wait=any) as group:
5✔
963
                for peer in self.peers.values():
5✔
964
                    await group.spawn(peer.received_htlc_removed_event.wait())
5✔
965

966
    def peer_closed(self, peer):
5✔
967
        for chan in self.channels_for_peer(peer.pubkey).values():
×
968
            chan.peer_state = PeerState.DISCONNECTED
×
969
            util.trigger_callback('channel', self.wallet, chan)
×
970
        super().peer_closed(peer)
×
971

972
    def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
5✔
973
        out = defaultdict(list)
5✔
974
        for chan in self.channels.values():
5✔
975
            d = chan.get_payments(status=status)
5✔
976
            for payment_hash, plist in d.items():
5✔
977
                out[payment_hash] += plist
5✔
978
        return out
5✔
979

980
    def get_payment_value(
5✔
981
            self, info: Optional['PaymentInfo'],
982
            plist: List[HTLCWithStatus]) -> Tuple[PaymentDirection, int, Optional[int], int]:
983
        """ fee_msat is included in amount_msat"""
984
        assert plist
×
985
        amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist)
×
986
        if all(x.direction == SENT for x in plist):
×
987
            direction = PaymentDirection.SENT
×
988
            fee_msat = (- info.amount_msat - amount_msat) if info else None
×
989
        elif all(x.direction == RECEIVED for x in plist):
×
990
            direction = PaymentDirection.RECEIVED
×
991
            fee_msat = None
×
992
        elif amount_msat < 0:
×
993
            direction = PaymentDirection.SELF_PAYMENT
×
994
            fee_msat = - amount_msat
×
995
        else:
996
            direction = PaymentDirection.FORWARDING
×
997
            fee_msat = - amount_msat
×
998
        timestamp = min([htlc_with_status.htlc.timestamp for htlc_with_status in plist])
×
999
        return direction, amount_msat, fee_msat, timestamp
×
1000

1001
    def get_lightning_history(self) -> Dict[str, LightningHistoryItem]:
5✔
1002
        """
1003
        side effect: sets defaults labels
1004
        note that the result is not ordered
1005
        """
1006
        out = {}
×
1007
        for payment_hash, plist in self.get_payments(status='settled').items():
×
1008
            if len(plist) == 0:
×
1009
                continue
×
1010
            key = payment_hash.hex()
×
1011
            info = self.get_payment_info(payment_hash)
×
1012
            direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
×
1013
            label = self.wallet.get_label_for_rhash(key)
×
1014
            if not label and direction == PaymentDirection.FORWARDING:
×
1015
                label = _('Forwarding')
×
1016
            preimage = self.get_preimage(payment_hash).hex()
×
NEW
1017
            group_id = self.swap_manager.get_group_id_for_payment_hash(payment_hash)
×
NEW
1018
            item = LightningHistoryItem(
×
1019
                type = 'payment',
1020
                payment_hash = payment_hash.hex(),
1021
                preimage = preimage,
1022
                amount_msat = amount_msat,
1023
                fee_msat = fee_msat,
1024
                group_id = group_id,
1025
                direction = direction,
1026
                timestamp = timestamp or 0,
1027
                label=label,
1028
            )
1029
            out[payment_hash] = item
×
1030
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
×
UNCOV
1031
            item = chan.get_funding_height()
×
1032
            if item is None:
×
1033
                continue
×
1034
            funding_txid, funding_height, funding_timestamp = item
×
NEW
1035
            label = _('Open channel') + ' ' + chan.get_id_for_log()
×
NEW
1036
            self.wallet.set_default_label(funding_txid, label)
×
NEW
1037
            self.wallet.set_group_label(funding_txid, label)
×
NEW
1038
            item = LightningHistoryItem(
×
1039
                type = 'channel_opening',
1040
                label = label,
1041
                group_id = funding_txid,
1042
                timestamp = funding_timestamp,
1043
                amount_msat = chan.balance(LOCAL, ctn=0),
1044
                fee_msat = None,
1045
                direction = PaymentDirection.RECEIVED,
1046
                payment_hash = None,
1047
                preimage = None,
1048
            )
1049
            out[funding_txid] = item
×
1050
            item = chan.get_closing_height()
×
1051
            if item is None:
×
1052
                continue
×
1053
            closing_txid, closing_height, closing_timestamp = item
×
NEW
1054
            label = _('Close channel') + ' ' + chan.get_id_for_log()
×
NEW
1055
            self.wallet.set_default_label(closing_txid, label)
×
NEW
1056
            self.wallet.set_group_label(closing_txid, label)
×
NEW
1057
            item = LightningHistoryItem(
×
1058
                type = 'channel_closing',
1059
                label = label,
1060
                group_id = closing_txid,
1061
                timestamp = closing_timestamp,
1062
                amount_msat = -chan.balance(LOCAL),
1063
                fee_msat = None,
1064
                direction = PaymentDirection.SENT,
1065
                payment_hash = None,
1066
                preimage = None,
1067
            )
UNCOV
1068
            out[closing_txid] = item
×
1069

1070
        # sanity check
NEW
1071
        balance_msat = sum([x.amount_msat for x in out.values()])
×
UNCOV
1072
        lb = sum(chan.balance(LOCAL) if not chan.is_closed() else 0
×
1073
                for chan in self.channels.values())
1074
        assert balance_msat  == lb
×
1075
        return out
×
1076

1077
    def get_groups_for_onchain_history(self) -> Dict[str, str]:
5✔
1078
        """
1079
        returns dict: txid -> group_id
1080
        side effect: sets default labels
1081
        """
NEW
1082
        groups = {}
×
1083
        # add funding events
NEW
1084
        for chan in itertools.chain(self.channels.values(), self.channel_backups.values()):  # type: AbstractChannel
×
NEW
1085
            item = chan.get_funding_height()
×
NEW
1086
            if item is None:
×
NEW
1087
                continue
×
NEW
1088
            funding_txid, funding_height, funding_timestamp = item
×
NEW
1089
            groups[funding_txid] = funding_txid
×
NEW
1090
            item = chan.get_closing_height()
×
NEW
1091
            if item is None:
×
NEW
1092
                continue
×
NEW
1093
            closing_txid, closing_height, closing_timestamp = item
×
NEW
1094
            groups[closing_txid] = closing_txid
×
1095

NEW
1096
        d = self.swap_manager.get_groups_for_onchain_history()
×
NEW
1097
        for txid, v in d.items():
×
NEW
1098
            group_id = v['group_id']
×
NEW
1099
            label = v.get('label')
×
NEW
1100
            group_label = v.get('group_label') or label
×
NEW
1101
            groups[txid] = group_id
×
NEW
1102
            if label:
×
NEW
1103
                self.wallet.set_default_label(txid, label)
×
NEW
1104
            if group_label:
×
NEW
1105
                self.wallet.set_group_label(group_id, group_label)
×
1106

NEW
1107
        return groups
×
1108

1109
    def channel_peers(self) -> List[bytes]:
5✔
1110
        node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
×
1111
        return node_ids
×
1112

1113
    def channels_for_peer(self, node_id):
5✔
1114
        assert type(node_id) is bytes
5✔
1115
        return {chan_id: chan for (chan_id, chan) in self.channels.items()
5✔
1116
                if chan.node_id == node_id}
1117

1118
    def channel_state_changed(self, chan: Channel):
5✔
1119
        if type(chan) is Channel:
×
1120
            self.save_channel(chan)
×
1121
        self.clear_invoices_cache()
×
1122
        util.trigger_callback('channel', self.wallet, chan)
×
1123

1124
    def save_channel(self, chan: Channel):
5✔
1125
        assert type(chan) is Channel
×
1126
        if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
×
1127
            raise Exception("Tried to save channel with next_point == current_point, this should not happen")
×
1128
        self.wallet.save_db()
×
1129
        util.trigger_callback('channel', self.wallet, chan)
×
1130

1131
    def channel_by_txo(self, txo: str) -> Optional[AbstractChannel]:
5✔
1132
        for chan in self.channels.values():
×
1133
            if chan.funding_outpoint.to_str() == txo:
×
1134
                return chan
×
1135
        for chan in self.channel_backups.values():
×
1136
            if chan.funding_outpoint.to_str() == txo:
×
1137
                return chan
×
1138

1139
    async def handle_onchain_state(self, chan: Channel):
5✔
1140
        if type(chan) is ChannelBackup:
×
1141
            util.trigger_callback('channel', self.wallet, chan)
×
1142
            return
×
1143

1144
        if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
×
1145
                and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height())):
1146
            self.logger.info(f"force-closing due to expiring htlcs")
×
1147
            await self.schedule_force_closing(chan.channel_id)
×
1148

1149
        elif chan.get_state() == ChannelState.FUNDED:
×
1150
            peer = self._peers.get(chan.node_id)
×
1151
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
×
1152
                peer.send_channel_ready(chan)
×
1153

1154
        elif chan.get_state() == ChannelState.OPEN:
×
1155
            peer = self._peers.get(chan.node_id)
×
1156
            if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
×
1157
                peer.maybe_update_fee(chan)
×
1158
                peer.maybe_send_announcement_signatures(chan)
×
1159

1160
        elif chan.get_state() == ChannelState.FORCE_CLOSING:
×
1161
            force_close_tx = chan.force_close_tx()
×
1162
            txid = force_close_tx.txid()
×
1163
            height = self.lnwatcher.adb.get_tx_height(txid).height
×
1164
            if height == TX_HEIGHT_LOCAL:
×
1165
                self.logger.info('REBROADCASTING CLOSING TX')
×
1166
                await self.network.try_broadcasting(force_close_tx, 'force-close')
×
1167

1168
    def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
5✔
1169
        for nodeid, peer in self.peers.items():
×
1170
            if scid_alias == self._scid_alias_of_node(nodeid):
×
1171
                return peer
×
1172

1173
    def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
5✔
1174
        # scid alias for just-in-time channels
1175
        return sha256(b'Electrum' + nodeid)[0:8]
×
1176

1177
    def get_scid_alias(self) -> bytes:
5✔
1178
        return self._scid_alias_of_node(self.node_keypair.pubkey)
×
1179

1180
    @log_exceptions
5✔
1181
    async def open_channel_just_in_time(
5✔
1182
        self,
1183
        *,
1184
        next_peer: Peer,
1185
        next_amount_msat_htlc: int,
1186
        next_cltv_abs: int,
1187
        payment_hash: bytes,
1188
        next_onion: OnionPacket,
1189
    ) -> str:
1190
        # if an exception is raised during negotiation, we raise an OnionRoutingFailure.
1191
        # this will cancel the incoming HTLC
1192
        try:
×
1193
            funding_sat = 2 * (next_amount_msat_htlc // 1000) # try to fully spend htlcs
×
1194
            password = self.wallet.get_unlocked_password() if self.wallet.has_password() else None
×
1195
            channel_opening_fee = next_amount_msat_htlc // 100
×
1196
            if channel_opening_fee // 1000 < self.config.ZEROCONF_MIN_OPENING_FEE:
×
1197
                self.logger.info(f'rejecting JIT channel: payment too low')
×
1198
                raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'payment too low')
×
1199
            self.logger.info(f'channel opening fee (sats): {channel_opening_fee//1000}')
×
1200
            next_chan, funding_tx = await self.open_channel_with_peer(
×
1201
                next_peer, funding_sat,
1202
                push_sat=0,
1203
                zeroconf=True,
1204
                public=False,
1205
                opening_fee=channel_opening_fee,
1206
                password=password,
1207
            )
1208
            async def wait_for_channel():
×
1209
                while not next_chan.is_open():
×
1210
                    await asyncio.sleep(1)
×
1211
            await util.wait_for2(wait_for_channel(), LN_P2P_NETWORK_TIMEOUT)
×
1212
            next_chan.save_remote_scid_alias(self._scid_alias_of_node(next_peer.pubkey))
×
1213
            self.logger.info(f'JIT channel is open')
×
1214
            next_amount_msat_htlc -= channel_opening_fee
×
1215
            # fixme: some checks are missing
1216
            htlc = next_peer.send_htlc(
×
1217
                chan=next_chan,
1218
                payment_hash=payment_hash,
1219
                amount_msat=next_amount_msat_htlc,
1220
                cltv_abs=next_cltv_abs,
1221
                onion=next_onion)
1222
            async def wait_for_preimage():
×
1223
                while self.get_preimage(payment_hash) is None:
×
1224
                    await asyncio.sleep(1)
×
1225
            await util.wait_for2(wait_for_preimage(), LN_P2P_NETWORK_TIMEOUT)
×
1226
        except OnionRoutingFailure:
×
1227
            raise
×
1228
        except Exception:
×
1229
            raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
×
1230
        # We have been paid and can broadcast
1231
        # todo: if broadcasting raises an exception, we should try to rebroadcast
1232
        await self.network.broadcast_transaction(funding_tx)
×
1233
        htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id)
×
1234
        return htlc_key
×
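# Worked example of the JIT sizing and fee rules above, with a hypothetical
# incoming HTLC; ZEROCONF_MIN_OPENING_FEE is config-dependent and not shown here.
next_amount_msat_htlc = 2_000_000                      # 2000 sat incoming HTLC
funding_sat = 2 * (next_amount_msat_htlc // 1000)      # 4000 sat channel capacity
channel_opening_fee = next_amount_msat_htlc // 100     # 20_000 msat == 20 sat
# rejected unless channel_opening_fee // 1000 >= config.ZEROCONF_MIN_OPENING_FEE
forwarded_msat = next_amount_msat_htlc - channel_opening_fee   # 1_980_000 msat reach the invoice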
1235

1236
    @log_exceptions
5✔
1237
    async def open_channel_with_peer(
5✔
1238
            self, peer, funding_sat, *,
1239
            push_sat: int = 0,
1240
            public: bool = False,
1241
            zeroconf: bool = False,
1242
            opening_fee: int = None,
1243
            password=None):
1244
        if self.config.ENABLE_ANCHOR_CHANNELS:
×
1245
            self.wallet.unlock(password)
×
1246
        coins = self.wallet.get_spendable_coins(None)
×
1247
        node_id = peer.pubkey
×
1248
        funding_tx = self.mktx_for_open_channel(
×
1249
            coins=coins,
1250
            funding_sat=funding_sat,
1251
            node_id=node_id,
1252
            fee_est=None)
1253
        chan, funding_tx = await self._open_channel_coroutine(
×
1254
            peer=peer,
1255
            funding_tx=funding_tx,
1256
            funding_sat=funding_sat,
1257
            push_sat=push_sat,
1258
            public=public,
1259
            zeroconf=zeroconf,
1260
            opening_fee=opening_fee,
1261
            password=password)
1262
        return chan, funding_tx
×
1263

1264
    @log_exceptions
5✔
1265
    async def _open_channel_coroutine(
5✔
1266
            self, *,
1267
            peer: Peer,
1268
            funding_tx: PartialTransaction,
1269
            funding_sat: int,
1270
            push_sat: int,
1271
            public: bool,
1272
            zeroconf=False,
1273
            opening_fee=None,
1274
            password: Optional[str],
1275
    ) -> Tuple[Channel, PartialTransaction]:
1276

1277
        if funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1278
            raise Exception(
×
1279
                _("Requested channel capacity is over maximum.")
1280
                + f"\n{funding_sat} sat > {self.config.LIGHTNING_MAX_FUNDING_SAT} sat"
1281
            )
1282
        coro = peer.channel_establishment_flow(
×
1283
            funding_tx=funding_tx,
1284
            funding_sat=funding_sat,
1285
            push_msat=push_sat * 1000,
1286
            public=public,
1287
            zeroconf=zeroconf,
1288
            opening_fee=opening_fee,
1289
            temp_channel_id=os.urandom(32))
1290
        chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
×
1291
        util.trigger_callback('channels_updated', self.wallet)
×
1292
        self.wallet.adb.add_transaction(funding_tx)  # save tx as local into the wallet
×
1293
        self.wallet.sign_transaction(funding_tx, password)
×
1294
        if funding_tx.is_complete() and not zeroconf:
×
1295
            await self.network.try_broadcasting(funding_tx, 'open_channel')
×
1296
        return chan, funding_tx
×
1297

1298
    def add_channel(self, chan: Channel):
5✔
1299
        with self.lock:
×
1300
            self._channels[chan.channel_id] = chan
×
1301
        self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
×
1302

1303
    def add_new_channel(self, chan: Channel):
5✔
1304
        self.add_channel(chan)
×
1305
        channels_db = self.db.get_dict('channels')
×
1306
        channels_db[chan.channel_id.hex()] = chan.storage
×
1307
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
×
1308
        try:
×
1309
            self.save_channel(chan)
×
1310
        except Exception:
×
1311
            chan.set_state(ChannelState.REDEEMED)
×
1312
            self.remove_channel(chan.channel_id)
×
1313
            raise
×
1314

1315
    def cb_data(self, node_id: bytes) -> bytes:
5✔
1316
        return CB_MAGIC_BYTES + node_id[0:NODE_ID_PREFIX_LEN]
×
1317

1318
    def decrypt_cb_data(self, encrypted_data, funding_address):
5✔
1319
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
×
1320
        nonce = funding_scripthash[0:12]
×
1321
        return chacha20_decrypt(key=self.backup_key, data=encrypted_data, nonce=nonce)
×
1322

1323
    def encrypt_cb_data(self, data, funding_address):
5✔
1324
        funding_scripthash = bytes.fromhex(address_to_scripthash(funding_address))
×
1325
        nonce = funding_scripthash[0:12]
×
1326
        # note: we are only using chacha20 instead of chacha20+poly1305 to save onchain space
1327
        #       (not have the 16 byte MAC). Otherwise, the latter would be preferable.
1328
        return chacha20_encrypt(key=self.backup_key, data=data, nonce=nonce)
×
1329

1330
    def mktx_for_open_channel(
5✔
1331
            self, *,
1332
            coins: Sequence[PartialTxInput],
1333
            funding_sat: int,
1334
            node_id: bytes,
1335
            fee_est=None) -> PartialTransaction:
1336
        from .wallet import get_locktime_for_new_transaction
×
1337

1338
        outputs = [PartialTxOutput.from_address_and_value(DummyAddress.CHANNEL, funding_sat)]
×
1339
        if self.has_recoverable_channels():
×
1340
            dummy_scriptpubkey = make_op_return(self.cb_data(node_id))
×
1341
            outputs.append(PartialTxOutput(scriptpubkey=dummy_scriptpubkey, value=0))
×
1342
        tx = self.wallet.make_unsigned_transaction(
×
1343
            coins=coins,
1344
            outputs=outputs,
1345
            fee=fee_est)
1346
        tx.set_rbf(False)
×
1347
        # rm randomness from locktime, as we use the locktime as entropy for deriving the funding_privkey
1348
        # (and it would be confusing to get a collision as a consequence of the randomness)
1349
        tx.locktime = get_locktime_for_new_transaction(self.network, include_random_component=False)
×
1350
        return tx
×
1351

1352
    def suggest_funding_amount(self, amount_to_pay, coins):
5✔
1353
        """ whether we can pay amount_sat after opening a new channel"""
1354
        num_sats_can_send = int(self.num_sats_can_send())
×
1355
        lightning_needed = amount_to_pay - num_sats_can_send
×
1356
        assert lightning_needed > 0
×
1357
        min_funding_sat = lightning_needed + (lightning_needed // 20) + 1000 # safety margin
×
1358
        min_funding_sat = max(min_funding_sat, 100_000) # at least 1mBTC
×
1359
        if min_funding_sat > self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1360
            return
×
1361
        fee_est = partial(self.config.estimate_fee, allow_fallback_to_static_rates=True)  # to avoid NoDynamicFeeEstimates
×
1362
        try:
×
1363
            self.mktx_for_open_channel(coins=coins, funding_sat=min_funding_sat, node_id=bytes(32), fee_est=fee_est)
×
1364
            funding_sat = min_funding_sat
×
1365
        except NotEnoughFunds:
×
1366
            return
×
1367
        # if available, suggest twice that amount:
1368
        if 2 * min_funding_sat <= self.config.LIGHTNING_MAX_FUNDING_SAT:
×
1369
            try:
×
1370
                self.mktx_for_open_channel(coins=coins, funding_sat=2*min_funding_sat, node_id=bytes(32), fee_est=fee_est)
×
1371
                funding_sat = 2 * min_funding_sat
×
1372
            except NotEnoughFunds:
×
1373
                pass
×
1374
        return funding_sat, min_funding_sat
×
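# Worked example of the sizing rule above, with hypothetical numbers:
# amount_to_pay = 150_000 sat while num_sats_can_send() == 30_000 sat.
lightning_needed = 150_000 - 30_000                                  # 120_000 sat
min_funding_sat = lightning_needed + lightning_needed // 20 + 1000   # 127_000 sat (5% + 1000 margin)
min_funding_sat = max(min_funding_sat, 100_000)                      # 1 mBTC floor; still 127_000
# if the wallet coins allow, 2 * 127_000 = 254_000 sat is suggested instead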
1375

1376
    def open_channel(
5✔
1377
            self, *,
1378
            connect_str: str,
1379
            funding_tx: PartialTransaction,
1380
            funding_sat: int,
1381
            push_amt_sat: int,
1382
            public: bool = False,
1383
            password: str = None,
1384
    ) -> Tuple[Channel, PartialTransaction]:
1385

1386
        fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
×
1387
        try:
×
1388
            peer = fut.result()
×
1389
        except concurrent.futures.TimeoutError:
×
1390
            raise Exception(_("add peer timed out"))
×
1391
        coro = self._open_channel_coroutine(
×
1392
            peer=peer,
1393
            funding_tx=funding_tx,
1394
            funding_sat=funding_sat,
1395
            push_sat=push_amt_sat,
1396
            public=public,
1397
            password=password)
1398
        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
×
1399
        try:
×
1400
            chan, funding_tx = fut.result()
×
1401
        except concurrent.futures.TimeoutError:
×
1402
            raise Exception(_("open_channel timed out"))
×
1403
        return chan, funding_tx
×
1404

1405
    def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
5✔
1406
        # First check against *real* SCIDs.
1407
        # This e.g. protects against maliciously chosen SCID aliases, and accidental collisions.
1408
        for chan in self.channels.values():
×
1409
            if chan.short_channel_id == short_channel_id:
×
1410
                return chan
×
1411
        # Now we also consider aliases.
1412
        # TODO we should split this as this search currently ignores the "direction"
1413
        #      of the aliases. We should only look at either the remote OR the local alias,
1414
        #      depending on context.
1415
        for chan in self.channels.values():
×
1416
            if chan.get_remote_scid_alias() == short_channel_id:
×
1417
                return chan
×
1418
            if chan.get_local_scid_alias() == short_channel_id:
×
1419
                return chan
×
1420

1421
    def can_pay_invoice(self, invoice: Invoice) -> bool:
5✔
1422
        assert invoice.is_lightning()
×
1423
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_send()
×
1424

1425
    @log_exceptions
5✔
1426
    async def pay_invoice(
5✔
1427
            self, invoice: str, *,
1428
            amount_msat: int = None,
1429
            attempts: int = None,  # used only in unit tests
1430
            full_path: LNPaymentPath = None,
1431
            channels: Optional[Sequence[Channel]] = None,
1432
    ) -> Tuple[bool, List[HtlcLog]]:
1433

1434
        lnaddr = self._check_invoice(invoice, amount_msat=amount_msat)
5✔
1435
        min_final_cltv_delta = lnaddr.get_min_final_cltv_delta()
5✔
1436
        payment_hash = lnaddr.paymenthash
5✔
1437
        key = payment_hash.hex()
5✔
1438
        payment_secret = lnaddr.payment_secret
5✔
1439
        invoice_pubkey = lnaddr.pubkey.serialize()
5✔
1440
        invoice_features = lnaddr.get_features()
5✔
1441
        r_tags = lnaddr.get_routing_info('r')
5✔
1442
        amount_to_pay = lnaddr.get_amount_msat()
5✔
1443
        status = self.get_payment_status(payment_hash)
5✔
1444
        if status == PR_PAID:
5✔
1445
            raise PaymentFailure(_("This invoice has been paid already"))
×
1446
        if status == PR_INFLIGHT:
5✔
1447
            raise PaymentFailure(_("A payment was already initiated for this invoice"))
×
1448
        if payment_hash in self.get_payments(status='inflight'):
5✔
1449
            raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
×
1450
        info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
5✔
1451
        self.save_payment_info(info)
5✔
1452
        self.wallet.set_label(key, lnaddr.get_description())
5✔
1453
        self.set_invoice_status(key, PR_INFLIGHT)
5✔
1454
        budget = PaymentFeeBudget.default(invoice_amount_msat=amount_to_pay, config=self.config)
5✔
1455
        if attempts is None and self.uses_trampoline():
5✔
1456
            # we don't expect lots of failed htlcs with trampoline, so we can fail sooner
1457
            attempts = 30
5✔
1458
        success = False
5✔
1459
        try:
5✔
1460
            await self.pay_to_node(
5✔
1461
                node_pubkey=invoice_pubkey,
1462
                payment_hash=payment_hash,
1463
                payment_secret=payment_secret,
1464
                amount_to_pay=amount_to_pay,
1465
                min_final_cltv_delta=min_final_cltv_delta,
1466
                r_tags=r_tags,
1467
                invoice_features=invoice_features,
1468
                attempts=attempts,
1469
                full_path=full_path,
1470
                channels=channels,
1471
                budget=budget,
1472
            )
1473
            success = True
5✔
1474
        except PaymentFailure as e:
5✔
1475
            self.logger.info(f'payment failure: {e!r}')
5✔
1476
            reason = str(e)
5✔
1477
        except ChannelDBNotLoaded as e:
5✔
1478
            self.logger.info(f'payment failure: {e!r}')
×
1479
            reason = str(e)
×
1480
        finally:
1481
            self.logger.info(f"pay_invoice ending session for RHASH={payment_hash.hex()}. {success=}")
5✔
1482
        if success:
5✔
1483
            self.set_invoice_status(key, PR_PAID)
5✔
1484
            util.trigger_callback('payment_succeeded', self.wallet, key)
5✔
1485
        else:
1486
            self.set_invoice_status(key, PR_UNPAID)
5✔
1487
            util.trigger_callback('payment_failed', self.wallet, key, reason)
5✔
1488
        log = self.logs[key]
5✔
1489
        return success, log
5✔
1490

1491
    async def pay_to_node(
5✔
1492
            self, *,
1493
            node_pubkey: bytes,
1494
            payment_hash: bytes,
1495
            payment_secret: bytes,
1496
            amount_to_pay: int,  # in msat
1497
            min_final_cltv_delta: int,
1498
            r_tags,
1499
            invoice_features: int,
1500
            attempts: int = None,
1501
            full_path: LNPaymentPath = None,
1502
            fwd_trampoline_onion: OnionPacket = None,
1503
            budget: PaymentFeeBudget,
1504
            channels: Optional[Sequence[Channel]] = None,
1505
            fw_payment_key=None,  # for forwarding
1506
    ) -> None:
1507

1508
        assert budget
5✔
1509
        assert budget.fee_msat >= 0, budget
5✔
1510
        assert budget.cltv >= 0, budget
5✔
1511

1512
        payment_key = payment_hash + payment_secret
5✔
1513
        assert payment_key not in self._paysessions
5✔
1514
        self._paysessions[payment_key] = paysession = PaySession(
5✔
1515
            payment_hash=payment_hash,
1516
            payment_secret=payment_secret,
1517
            initial_trampoline_fee_level=self.config.INITIAL_TRAMPOLINE_FEE_LEVEL,
1518
            invoice_features=invoice_features,
1519
            r_tags=r_tags,
1520
            min_final_cltv_delta=min_final_cltv_delta,
1521
            amount_to_pay=amount_to_pay,
1522
            invoice_pubkey=node_pubkey,
1523
            uses_trampoline=self.uses_trampoline(),
1524
            use_two_trampolines=self.config.LIGHTNING_LEGACY_ADD_TRAMPOLINE,
1525
        )
1526
        self.logs[payment_hash.hex()] = log = []  # TODO incl payment_secret in key (re trampoline forwarding)
5✔
1527

1528
        paysession.logger.info(
5✔
1529
            f"pay_to_node starting session for RHASH={payment_hash.hex()}. "
1530
            f"using_trampoline={self.uses_trampoline()}. "
1531
            f"invoice_features={paysession.invoice_features.get_names()}. "
1532
            f"{amount_to_pay=} msat. {budget=}")
1533
        if not self.uses_trampoline():
5✔
1534
            self.logger.info(
5✔
1535
                f"gossip_db status. sync progress: {self.network.lngossip.get_sync_progress_estimate()}. "
1536
                f"num_nodes={self.channel_db.num_nodes}, "
1537
                f"num_channels={self.channel_db.num_channels}, "
1538
                f"num_policies={self.channel_db.num_policies}.")
1539

1540
        # when encountering trampoline forwarding difficulties in the legacy case, we
1541
        # sometimes need to fall back to a single trampoline forwarder, at the expense
1542
        # of privacy
1543
        try:
5✔
1544
            while True:
5✔
1545
                if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
5✔
1546
                    # 1. create a set of routes for the remaining amount.
1547
                    # note: path-finding runs in a separate thread so that we don't block the asyncio loop
1548
                    # graph updates might occur during the computation
1549
                    remaining_fee_budget_msat = (budget.fee_msat * amount_to_send) // amount_to_pay
5✔
1550
                    routes = self.create_routes_for_payment(
5✔
1551
                        paysession=paysession,
1552
                        amount_msat=amount_to_send,
1553
                        full_path=full_path,
1554
                        fwd_trampoline_onion=fwd_trampoline_onion,
1555
                        channels=channels,
1556
                        budget=budget._replace(fee_msat=remaining_fee_budget_msat),
1557
                    )
1558
                    # 2. send htlcs
1559
                    async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
5✔
1560
                        await self.pay_to_route(
5✔
1561
                            paysession=paysession,
1562
                            sent_htlc_info=sent_htlc_info,
1563
                            min_final_cltv_delta=cltv_delta,
1564
                            trampoline_onion=trampoline_onion,
1565
                            fw_payment_key=fw_payment_key,
1566
                        )
1567
                    # invoice_status is triggered in self.set_invoice_status when it actually changes.
1568
                    # It is also triggered here to update progress for a lightning payment in the GUI
1569
                    # (e.g. attempt counter)
1570
                    util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
5✔
1571
                # 3. wait for one of the sent htlcs to resolve (success or failure)
1572
                htlc_log = await paysession.wait_for_one_htlc_to_resolve()  # TODO maybe wait a bit, more failures might come
5✔
1573
                log.append(htlc_log)
5✔
1574
                if htlc_log.success:
5✔
1575
                    if self.network.path_finder:
5✔
1576
                        # TODO: report every route to liquidity hints for mpp
1577
                        # in the case of success, we report channels of the
1578
                        # route as being able to send the same amount in the future,
1579
                        # since we do not know the actual capacity
1580
                        self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat)
5✔
1581
                        # remove inflight htlcs from liquidity hints
1582
                        self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
5✔
1583
                    return
5✔
1584
                # htlc failed
1585
                # if we get a tmp channel failure, it might work to split the amount and try more routes
1586
                # if we get a channel update, we might retry the same route and amount
1587
                route = htlc_log.route
5✔
1588
                sender_idx = htlc_log.sender_idx
5✔
1589
                failure_msg = htlc_log.failure_msg
5✔
1590
                if sender_idx is None:
5✔
1591
                    raise PaymentFailure(failure_msg.code_name())
4✔
1592
                erring_node_id = route[sender_idx].node_id
5✔
1593
                code, data = failure_msg.code, failure_msg.data
5✔
1594
                self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. "
5✔
1595
                                 f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}")
1596
                self.logger.info(f"error reported by {erring_node_id.hex()}")
5✔
1597
                if code == OnionFailureCode.MPP_TIMEOUT:
5✔
1598
                    raise PaymentFailure(failure_msg.code_name())
5✔
1599
                # errors returned by the next trampoline.
1600
                if fwd_trampoline_onion and code in [
5✔
1601
                        OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
1602
                        OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON]:
1603
                    raise failure_msg
×
1604
                # trampoline
1605
                if self.uses_trampoline():
5✔
1606
                    paysession.handle_failed_trampoline_htlc(
5✔
1607
                        htlc_log=htlc_log, failure_msg=failure_msg)
1608
                else:
1609
                    self.handle_error_code_from_failed_htlc(
5✔
1610
                        route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
1611
                # max attempts or timeout
1612
                if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
5✔
1613
                    raise PaymentFailure('Giving up after %d attempts' % len(log))
5✔
1614
        finally:
1615
            paysession.is_active = False
5✔
1616
            if paysession.can_be_deleted():
5✔
1617
                self._paysessions.pop(payment_key)
5✔
1618
            paysession.logger.info(f"pay_to_node ending session for RHASH={payment_hash.hex()}")
5✔
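# The fee budget above is scaled pro-rata to the amount still outstanding;
# a worked example with hypothetical numbers:
fee_budget_msat = 5_000
amount_to_pay = 1_000_000
amount_to_send = 400_000        # still unpaid after earlier parts succeeded
remaining_fee_budget_msat = (fee_budget_msat * amount_to_send) // amount_to_pay   # 2_000 msat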
1619

1620
    async def pay_to_route(
5✔
1621
            self, *,
1622
            paysession: PaySession,
1623
            sent_htlc_info: SentHtlcInfo,
1624
            min_final_cltv_delta: int,
1625
            trampoline_onion: Optional[OnionPacket] = None,
1626
            fw_payment_key: str = None,
1627
    ) -> None:
1628
        """Sends a single HTLC."""
1629
        shi = sent_htlc_info
5✔
1630
        del sent_htlc_info  # just renamed
5✔
1631
        short_channel_id = shi.route[0].short_channel_id
5✔
1632
        chan = self.get_channel_by_short_id(short_channel_id)
5✔
1633
        assert chan, ShortChannelID(short_channel_id)
5✔
1634
        peer = self._peers.get(shi.route[0].node_id)
5✔
1635
        if not peer:
5✔
1636
            raise PaymentFailure('Dropped peer')
×
1637
        await peer.initialized
5✔
1638
        htlc = peer.pay(
5✔
1639
            route=shi.route,
1640
            chan=chan,
1641
            amount_msat=shi.amount_msat,
1642
            total_msat=shi.bucket_msat,
1643
            payment_hash=paysession.payment_hash,
1644
            min_final_cltv_delta=min_final_cltv_delta,
1645
            payment_secret=shi.payment_secret_bucket,
1646
            trampoline_onion=trampoline_onion)
1647

1648
        key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
5✔
1649
        self.sent_htlcs_info[key] = shi
5✔
1650
        paysession.add_new_htlc(shi)
5✔
1651
        if fw_payment_key:
5✔
1652
            htlc_key = serialize_htlc_key(short_channel_id, htlc.htlc_id)
5✔
1653
            self.logger.info(f'adding active forwarding {fw_payment_key}')
5✔
1654
            self.active_forwardings[fw_payment_key].append(htlc_key)
5✔
1655
        if self.network.path_finder:
5✔
1656
            # add inflight htlcs to liquidity hints
1657
            self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
5✔
1658
        util.trigger_callback('htlc_added', chan, htlc, SENT)
5✔
1659

1660
    def handle_error_code_from_failed_htlc(
5✔
1661
            self,
1662
            *,
1663
            route: LNPaymentRoute,
1664
            sender_idx: int,
1665
            failure_msg: OnionRoutingFailure,
1666
            amount: int) -> None:
1667

1668
        assert self.channel_db  # cannot be in trampoline mode
5✔
1669
        assert self.network.path_finder
5✔
1670

1671
        # remove inflight htlcs from liquidity hints
1672
        self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False)
5✔
1673

1674
        code, data = failure_msg.code, failure_msg.data
5✔
1675
        # TODO can we use lnmsg.OnionWireSerializer here?
1676
        # TODO update onion_wire.csv
1677
        # handle some specific error codes
1678
        failure_codes = {
5✔
1679
            OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0,
1680
            OnionFailureCode.AMOUNT_BELOW_MINIMUM: 8,
1681
            OnionFailureCode.FEE_INSUFFICIENT: 8,
1682
            OnionFailureCode.INCORRECT_CLTV_EXPIRY: 4,
1683
            OnionFailureCode.EXPIRY_TOO_SOON: 0,
1684
            OnionFailureCode.CHANNEL_DISABLED: 2,
1685
        }
1686
        try:
5✔
1687
            failing_channel = route[sender_idx + 1].short_channel_id
5✔
1688
        except IndexError:
5✔
1689
            raise PaymentFailure(f'payment destination reported error: {failure_msg.code_name()}') from None
5✔
1690

1691
        # TODO: handle unknown next peer?
1692
        # handle failure codes that include a channel update
1693
        if code in failure_codes:
5✔
1694
            offset = failure_codes[code]
5✔
1695
            channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big")
5✔
1696
            channel_update_as_received = data[offset+2: offset+2+channel_update_len]
5✔
1697
            payload = self._decode_channel_update_msg(channel_update_as_received)
5✔
1698
            if payload is None:
5✔
1699
                self.logger.info(f'could not decode channel_update for failed htlc: '
×
1700
                                 f'{channel_update_as_received.hex()}')
1701
                blacklist = True
×
1702
            elif payload.get('short_channel_id') != failing_channel:
5✔
1703
                self.logger.info(f'short_channel_id in channel_update does not match our route')
×
1704
                blacklist = True
×
1705
            else:
1706
                # apply the channel update or get blacklisted
1707
                blacklist, update = self._handle_chanupd_from_failed_htlc(
5✔
1708
                    payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg)
1709
                # we interpret a temporary channel failure as a liquidity issue
1710
                # in the channel and update our liquidity hints accordingly
1711
                if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE:
5✔
1712
                    self.network.path_finder.update_liquidity_hints(
5✔
1713
                        route,
1714
                        amount,
1715
                        failing_channel=ShortChannelID(failing_channel))
1716
                # if we can't decide on some action, we are stuck
1717
                if not (blacklist or update):
5✔
1718
                    raise PaymentFailure(failure_msg.code_name())
×
1719
        # for errors that do not include a channel update
1720
        else:
1721
            blacklist = True
5✔
1722
        if blacklist:
5✔
1723
            self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel)
5✔
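# A minimal sketch of how the channel_update embedded in the failure data is
# sliced out above; the offset comes from the failure_codes table
# (e.g. 8 for FEE_INSUFFICIENT, 0 for TEMPORARY_CHANNEL_FAILURE):
def extract_channel_update(data: bytes, offset: int) -> bytes:
    channel_update_len = int.from_bytes(data[offset:offset + 2], byteorder="big")
    return data[offset + 2: offset + 2 + channel_update_len]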
1724

1725
    def _handle_chanupd_from_failed_htlc(
5✔
1726
        self, payload, *,
1727
        route: LNPaymentRoute,
1728
        sender_idx: int,
1729
        failure_msg: OnionRoutingFailure,
1730
    ) -> Tuple[bool, bool]:
1731
        blacklist = False
5✔
1732
        update = False
5✔
1733
        try:
5✔
1734
            r = self.channel_db.add_channel_update(payload, verify=True)
5✔
1735
        except InvalidGossipMsg:
×
1736
            return True, False  # blacklist
×
1737
        short_channel_id = ShortChannelID(payload['short_channel_id'])
5✔
1738
        if r == UpdateStatus.GOOD:
5✔
1739
            self.logger.info(f"applied channel update to {short_channel_id}")
×
1740
            # TODO: add test for this
1741
            # FIXME: this does not work for our own unannounced channels.
1742
            for chan in self.channels.values():
×
1743
                if chan.short_channel_id == short_channel_id:
×
1744
                    chan.set_remote_update(payload)
×
1745
            update = True
×
1746
        elif r == UpdateStatus.ORPHANED:
5✔
1747
            # maybe it is a private channel (and data in invoice was outdated)
1748
            self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
5✔
1749
            start_node_id = route[sender_idx].node_id
5✔
1750
            cache_ttl = None
5✔
1751
            if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED:
5✔
1752
                # eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay
1753
                # a mobile phone with the app closed. So we cache this with a short TTL.
1754
                cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT
×
1755
            update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl)
5✔
1756
            blacklist = not update
5✔
1757
        elif r == UpdateStatus.EXPIRED:
×
1758
            blacklist = True
×
1759
        elif r == UpdateStatus.DEPRECATED:
×
1760
            self.logger.info(f'channel update is not more recent.')
×
1761
            blacklist = True
×
1762
        elif r == UpdateStatus.UNCHANGED:
×
1763
            blacklist = True
×
1764
        return blacklist, update
5✔
1765

1766
    @classmethod
5✔
1767
    def _decode_channel_update_msg(cls, chan_upd_msg: bytes) -> Optional[Dict[str, Any]]:
5✔
1768
        channel_update_as_received = chan_upd_msg
5✔
1769
        channel_update_typed = (258).to_bytes(length=2, byteorder="big") + channel_update_as_received
5✔
1770
        # note: some nodes put channel updates in error msgs with the leading msg_type already there.
1771
        #       we try decoding both ways here.
1772
        try:
5✔
1773
            message_type, payload = decode_msg(channel_update_typed)
5✔
1774
            if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
5✔
1775
            payload['raw'] = channel_update_typed
5✔
1776
            return payload
5✔
1777
        except Exception:  # FIXME: too broad
5✔
1778
            try:
5✔
1779
                message_type, payload = decode_msg(channel_update_as_received)
5✔
1780
                if payload['chain_hash'] != constants.net.rev_genesis_bytes(): raise Exception()
5✔
1781
                payload['raw'] = channel_update_as_received
5✔
1782
                return payload
5✔
1783
            except Exception:
5✔
1784
                return None
5✔
1785

1786
    def _check_invoice(self, invoice: str, *, amount_msat: int = None) -> LnAddr:
5✔
1787
        """Parses and validates a bolt11 invoice str into a LnAddr.
1788
        Includes pre-payment checks external to the parser.
1789
        """
1790
        addr = lndecode(invoice)
5✔
1791
        if addr.is_expired():
5✔
1792
            raise InvoiceError(_("This invoice has expired"))
×
1793
        # check amount
1794
        if amount_msat:  # replace amt in invoice. main usecase is paying zero amt invoices
5✔
1795
            existing_amt_msat = addr.get_amount_msat()
×
1796
            if existing_amt_msat and amount_msat < existing_amt_msat:
×
1797
                raise Exception("cannot pay lower amt than what is originally in LN invoice")
×
1798
            addr.amount = Decimal(amount_msat) / COIN / 1000
×
1799
        if addr.amount is None:
5✔
1800
            raise InvoiceError(_("Missing amount"))
×
1801
        # check cltv
1802
        if addr.get_min_final_cltv_delta() > lnutil.NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE:
5✔
1803
            raise InvoiceError("{}\n{}".format(
5✔
1804
                _("Invoice wants us to risk locking funds for unreasonably long."),
1805
                f"min_final_cltv_delta: {addr.get_min_final_cltv_delta()}"))
1806
        # check features
1807
        addr.validate_and_compare_features(self.features)
5✔
1808
        return addr
5✔
1809

1810
    def is_trampoline_peer(self, node_id: bytes) -> bool:
5✔
1811
        # until trampoline is advertised in lnfeatures, check against hardcoded list
1812
        if is_hardcoded_trampoline(node_id):
5✔
1813
            return True
5✔
1814
        peer = self._peers.get(node_id)
×
1815
        if not peer:
×
1816
            return False
×
1817
        return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)
×
1818
                or peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM))
1819

1820
    def suggest_peer(self) -> Optional[bytes]:
5✔
1821
        if not self.uses_trampoline():
×
1822
            return self.lnrater.suggest_peer()
×
1823
        else:
1824
            return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
×
1825

1826
    def suggest_splits(
5✔
1827
        self,
1828
        *,
1829
        amount_msat: int,
1830
        final_total_msat: int,
1831
        my_active_channels: Sequence[Channel],
1832
        invoice_features: LnFeatures,
1833
        r_tags,
1834
    ) -> List['SplitConfigRating']:
1835
        channels_with_funds = {
5✔
1836
            (chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL))
1837
            for chan in my_active_channels
1838
        }
1839
        self.logger.info(f"channels_with_funds: {channels_with_funds}")
5✔
1840
        exclude_single_part_payments = False
5✔
1841
        if self.uses_trampoline():
5✔
1842
            # in the case of a legacy payment, we don't allow splitting via different
1843
            # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
1844
            is_legacy, _ = is_legacy_relay(invoice_features, r_tags)
5✔
1845
            exclude_multinode_payments = is_legacy
5✔
1846
            # we don't split within a channel when sending to a trampoline node,
1847
            # the trampoline node will split for us
1848
            exclude_single_channel_splits = True
5✔
1849
        else:
1850
            exclude_multinode_payments = False
5✔
1851
            exclude_single_channel_splits = False
5✔
1852
            if invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and not self.config.TEST_FORCE_DISABLE_MPP:
5✔
1853
                # if amt is still large compared to total_msat, split it:
1854
                if (amount_msat / final_total_msat > self.MPP_SPLIT_PART_FRACTION
5✔
1855
                        and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT):
1856
                    exclude_single_part_payments = True
×
1857

1858
        def get_splits():
5✔
1859
            return suggest_splits(
5✔
1860
                amount_msat,
1861
                channels_with_funds,
1862
                exclude_single_part_payments=exclude_single_part_payments,
1863
                exclude_multinode_payments=exclude_multinode_payments,
1864
                exclude_single_channel_splits=exclude_single_channel_splits
1865
            )
1866

1867
        split_configurations = get_splits()
5✔
1868
        if not split_configurations and exclude_single_part_payments:
5✔
1869
            exclude_single_part_payments = False
×
1870
            split_configurations = get_splits()
×
1871
        self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
5✔
1872
        return split_configurations
5✔
1873

1874
    async def create_routes_for_payment(
5✔
1875
            self, *,
1876
            paysession: PaySession,
1877
            amount_msat: int,        # part of payment amount we want routes for now
1878
            fwd_trampoline_onion: OnionPacket = None,
1879
            full_path: LNPaymentPath = None,
1880
            channels: Optional[Sequence[Channel]] = None,
1881
            budget: PaymentFeeBudget,
1882
    ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
1883

1884
        """Creates multiple routes for splitting a payment over the available
1885
        private channels.
1886

1887
        We first try to conduct the payment over a single channel. If that fails
1888
        and mpp is supported by the receiver, we will split the payment."""
1889
        trampoline_features = LnFeatures.VAR_ONION_OPT
5✔
1890
        local_height = self.network.get_local_height()
5✔
1891
        fee_related_error = None  # type: Optional[FeeBudgetExceeded]
5✔
1892
        if channels:
5✔
1893
            my_active_channels = channels
×
1894
        else:
1895
            my_active_channels = [
5✔
1896
                chan for chan in self.channels.values() if
1897
                chan.is_active() and not chan.is_frozen_for_sending()]
1898
        # try random order
1899
        random.shuffle(my_active_channels)
5✔
1900
        split_configurations = self.suggest_splits(
5✔
1901
            amount_msat=amount_msat,
1902
            final_total_msat=paysession.amount_to_pay,
1903
            my_active_channels=my_active_channels,
1904
            invoice_features=paysession.invoice_features,
1905
            r_tags=paysession.r_tags,
1906
        )
1907
        for sc in split_configurations:
5✔
1908
            is_multichan_mpp = len(sc.config.items()) > 1
5✔
1909
            is_mpp = sc.config.number_parts() > 1
5✔
1910
            if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
5✔
1911
                continue
5✔
1912
            if not is_mpp and self.config.TEST_FORCE_MPP:
5✔
1913
                continue
5✔
1914
            if is_mpp and self.config.TEST_FORCE_DISABLE_MPP:
5✔
1915
                continue
×
1916
            self.logger.info(f"trying split configuration: {sc.config.values()} rating: {sc.rating}")
5✔
1917
            routes = []
5✔
1918
            try:
5✔
1919
                if self.uses_trampoline():
5✔
1920
                    per_trampoline_channel_amounts = defaultdict(list)
5✔
1921
                    # categorize by trampoline nodes for trampoline mpp construction
1922
                    for (chan_id, _), part_amounts_msat in sc.config.items():
5✔
1923
                        chan = self.channels[chan_id]
5✔
1924
                        for part_amount_msat in part_amounts_msat:
5✔
1925
                            per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
5✔
1926
                    # for each trampoline forwarder, construct mpp trampoline
1927
                    for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
5✔
1928
                        per_trampoline_amount = sum([x[1] for x in trampoline_parts])
5✔
1929
                        trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
5✔
1930
                            amount_msat=per_trampoline_amount,
1931
                            total_msat=paysession.amount_to_pay,
1932
                            min_final_cltv_delta=paysession.min_final_cltv_delta,
1933
                            my_pubkey=self.node_keypair.pubkey,
1934
                            invoice_pubkey=paysession.invoice_pubkey,
1935
                            invoice_features=paysession.invoice_features,
1936
                            node_id=trampoline_node_id,
1937
                            r_tags=paysession.r_tags,
1938
                            payment_hash=paysession.payment_hash,
1939
                            payment_secret=paysession.payment_secret,
1940
                            local_height=local_height,
1941
                            trampoline_fee_level=paysession.trampoline_fee_level,
1942
                            use_two_trampolines=paysession.use_two_trampolines,
1943
                            failed_routes=paysession.failed_trampoline_routes,
1944
                            budget=budget._replace(fee_msat=budget.fee_msat // len(per_trampoline_channel_amounts)),
1945
                        )
1946
                        # node_features is only used to determine is_tlv
1947
                        per_trampoline_secret = os.urandom(32)
5✔
1948
                        per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
5✔
1949
                        self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
5✔
1950
                        self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
5✔
1951
                        self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
5✔
1952
                        for chan_id, part_amount_msat in trampoline_parts:
5✔
1953
                            chan = self.channels[chan_id]
5✔
1954
                            margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat
5✔
1955
                            delta_fee = min(per_trampoline_fees, margin)
5✔
1956
                            # TODO: distribute trampoline fee over several channels?
1957
                            part_amount_msat_with_fees = part_amount_msat + delta_fee
5✔
1958
                            per_trampoline_fees -= delta_fee
5✔
1959
                            route = [
5✔
1960
                                RouteEdge(
1961
                                    start_node=self.node_keypair.pubkey,
1962
                                    end_node=trampoline_node_id,
1963
                                    short_channel_id=chan.short_channel_id,
1964
                                    fee_base_msat=0,
1965
                                    fee_proportional_millionths=0,
1966
                                    cltv_delta=0,
1967
                                    node_features=trampoline_features)
1968
                            ]
1969
                            self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
5✔
1970
                            shi = SentHtlcInfo(
5✔
1971
                                route=route,
1972
                                payment_secret_orig=paysession.payment_secret,
1973
                                payment_secret_bucket=per_trampoline_secret,
1974
                                amount_msat=part_amount_msat_with_fees,
1975
                                bucket_msat=per_trampoline_amount_with_fees,
1976
                                amount_receiver_msat=part_amount_msat,
1977
                                trampoline_fee_level=paysession.trampoline_fee_level,
1978
                                trampoline_route=trampoline_route,
1979
                            )
1980
                            routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
5✔
1981
                        if per_trampoline_fees != 0:
5✔
1982
                            e = 'not enough margin to pay trampoline fee'
×
1983
                            self.logger.info(e)
×
1984
                            raise FeeBudgetExceeded(e)
×
1985
                else:
1986
                    # We atomically loop through a split configuration. If there was
1987
                    # a failure to find a path for a single part, we try the next configuration
1988
                    for (chan_id, _), part_amounts_msat in sc.config.items():
5✔
1989
                        for part_amount_msat in part_amounts_msat:
5✔
1990
                            channel = self.channels[chan_id]
5✔
1991
                            route = await run_in_thread(
5✔
1992
                                partial(
1993
                                    self.create_route_for_single_htlc,
1994
                                    amount_msat=part_amount_msat,
1995
                                    invoice_pubkey=paysession.invoice_pubkey,
1996
                                    min_final_cltv_delta=paysession.min_final_cltv_delta,
1997
                                    r_tags=paysession.r_tags,
1998
                                    invoice_features=paysession.invoice_features,
1999
                                    my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
2000
                                    full_path=full_path,
2001
                                    budget=budget._replace(fee_msat=budget.fee_msat // sc.config.number_parts()),
2002
                                )
2003
                            )
2004
                            shi = SentHtlcInfo(
5✔
2005
                                route=route,
2006
                                payment_secret_orig=paysession.payment_secret,
2007
                                payment_secret_bucket=paysession.payment_secret,
2008
                                amount_msat=part_amount_msat,
2009
                                bucket_msat=paysession.amount_to_pay,
2010
                                amount_receiver_msat=part_amount_msat,
2011
                                trampoline_fee_level=None,
2012
                                trampoline_route=None,
2013
                            )
2014
                            routes.append((shi, paysession.min_final_cltv_delta, fwd_trampoline_onion))
5✔
2015
            except NoPathFound:
5✔
2016
                continue
5✔
2017
            except FeeBudgetExceeded as e:
5✔
2018
                fee_related_error = e
×
2019
                continue
×
2020
            for route in routes:
5✔
2021
                yield route
5✔
2022
            return
5✔
2023
        if fee_related_error is not None:
5✔
2024
            raise fee_related_error
×
2025
        raise NoPathFound()
5✔
2026

2027
    @profiler
5✔
2028
    def create_route_for_single_htlc(
5✔
2029
            self, *,
2030
            amount_msat: int,  # that final receiver gets
2031
            invoice_pubkey: bytes,
2032
            min_final_cltv_delta: int,
2033
            r_tags,
2034
            invoice_features: int,
2035
            my_sending_channels: List[Channel],
2036
            full_path: Optional[LNPaymentPath],
2037
            budget: PaymentFeeBudget,
2038
    ) -> LNPaymentRoute:
2039

2040
        my_sending_aliases = set(chan.get_local_scid_alias() for chan in my_sending_channels)
5✔
2041
        my_sending_channels = {chan.short_channel_id: chan for chan in my_sending_channels
5✔
2042
            if chan.short_channel_id is not None}
2043
        # Collect all private edges from route hints.
2044
        # Note: if some route hints are multiple edges long, and these paths cross each other,
2045
        #       we allow our path finding to cross the paths; i.e. the route hints are not isolated.
2046
        private_route_edges = {}  # type: Dict[ShortChannelID, RouteEdge]
5✔
2047
        for private_path in r_tags:
5✔
2048
            # we need to shift the node pubkey by one towards the destination:
2049
            private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
5✔
2050
            private_path_rest = [edge[1:] for edge in private_path]
5✔
2051
            start_node = private_path[0][0]
5✔
2052
            # remove aliases from direct routes
2053
            if len(private_path) == 1 and private_path[0][1] in my_sending_aliases:
5✔
2054
                self.logger.info(f'create_route: skipping alias {ShortChannelID(private_path[0][1])}')
×
2055
                continue
×
2056
            for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
5✔
2057
                short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_delta = edge_rest
5✔
2058
                short_channel_id = ShortChannelID(short_channel_id)
5✔
2059
                # if we have a routing policy for this edge in the db, that takes precedence,
2060
                # as it is likely from a previous failure
2061
                channel_policy = self.channel_db.get_policy_for_node(
5✔
2062
                    short_channel_id=short_channel_id,
2063
                    node_id=start_node,
2064
                    my_channels=my_sending_channels)
2065
                if channel_policy:
5✔
2066
                    fee_base_msat = channel_policy.fee_base_msat
5✔
2067
                    fee_proportional_millionths = channel_policy.fee_proportional_millionths
5✔
2068
                    cltv_delta = channel_policy.cltv_delta
5✔
2069
                node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
5✔
2070
                route_edge = RouteEdge(
5✔
2071
                        start_node=start_node,
2072
                        end_node=end_node,
2073
                        short_channel_id=short_channel_id,
2074
                        fee_base_msat=fee_base_msat,
2075
                        fee_proportional_millionths=fee_proportional_millionths,
2076
                        cltv_delta=cltv_delta,
2077
                        node_features=node_info.features if node_info else 0)
2078
                private_route_edges[route_edge.short_channel_id] = route_edge
5✔
2079
                start_node = end_node
5✔
2080
        # now find a route, end to end: between us and the recipient
2081
        try:
5✔
2082
            route = self.network.path_finder.find_route(
5✔
2083
                nodeA=self.node_keypair.pubkey,
2084
                nodeB=invoice_pubkey,
2085
                invoice_amount_msat=amount_msat,
2086
                path=full_path,
2087
                my_sending_channels=my_sending_channels,
2088
                private_route_edges=private_route_edges)
2089
        except NoChannelPolicy as e:
5✔
2090
            raise NoPathFound() from e
×
2091
        if not route:
5✔
2092
            raise NoPathFound()
5✔
2093
        if not is_route_within_budget(
5✔
2094
            route, budget=budget, amount_msat_for_dest=amount_msat, cltv_delta_for_dest=min_final_cltv_delta,
2095
        ):
2096
            self.logger.info(f"rejecting route (exceeds budget): {route=}. {budget=}")
×
2097
            raise FeeBudgetExceeded()
×
2098
        assert len(route) > 0
5✔
2099
        if route[-1].end_node != invoice_pubkey:
5✔
2100
            raise LNPathInconsistent("last node_id != invoice pubkey")
5✔
2101
        # add features from invoice
2102
        route[-1].node_features |= invoice_features
5✔
2103
        return route
5✔
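# Sketch of the node shift applied to the 'r' route hints above, using toy
# placeholder ids; each hint edge is (node_id, scid, fee_base, fee_ppm, cltv_delta):
private_path = [(b'A', b'scid1', 0, 0, 144), (b'B', b'scid2', 0, 0, 144)]
invoice_pubkey = b'D'
end_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]   # [b'B', b'D']
edge_rest = [edge[1:] for edge in private_path]   # policy data for edges A->B and B->D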
2104

2105
    def clear_invoices_cache(self):
5✔
2106
        self._bolt11_cache.clear()
×
2107

2108
    def get_bolt11_invoice(
5✔
2109
            self, *,
2110
            payment_hash: bytes,
2111
            amount_msat: Optional[int],
2112
            message: str,
2113
            expiry: int,  # expiration of invoice (in seconds, relative)
2114
            fallback_address: Optional[str],
2115
            channels: Optional[Sequence[Channel]] = None,
2116
            min_final_cltv_expiry_delta: Optional[int] = None,
2117
    ) -> Tuple[LnAddr, str]:
2118
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
×
2119

2120
        pair = self._bolt11_cache.get(payment_hash)
×
2121
        if pair:
×
2122
            lnaddr, invoice = pair
×
2123
            assert lnaddr.get_amount_msat() == amount_msat
×
2124
            return pair
×
2125

2126
        assert amount_msat is None or amount_msat > 0
×
2127
        timestamp = int(time.time())
×
2128
        routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels)
×
2129
        self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}")
×
2130
        invoice_features = self.features.for_invoice()
×
2131
        if not self.uses_trampoline():
×
2132
            invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
×
2133
        payment_secret = self.get_payment_secret(payment_hash)
×
2134
        amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
×
2135
        if expiry == 0:
×
2136
            expiry = LN_EXPIRY_NEVER
×
2137
        if min_final_cltv_expiry_delta is None:
×
2138
            min_final_cltv_expiry_delta = MIN_FINAL_CLTV_DELTA_FOR_INVOICE
×
2139
        lnaddr = LnAddr(
×
2140
            paymenthash=payment_hash,
2141
            amount=amount_btc,
2142
            tags=[
2143
                ('d', message),
2144
                ('c', min_final_cltv_expiry_delta),
2145
                ('x', expiry),
2146
                ('9', invoice_features),
2147
                ('f', fallback_address),
2148
            ] + routing_hints,
2149
            date=timestamp,
2150
            payment_secret=payment_secret)
2151
        invoice = lnencode(lnaddr, self.node_keypair.privkey)
×
2152
        pair = lnaddr, invoice
×
2153
        self._bolt11_cache[payment_hash] = pair
×
2154
        return pair
×
2155
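    # Minimal usage sketch of get_bolt11_invoice (receiver name and amounts are
    # hypothetical). The method returns the decoded LnAddr next to the bech32
    # string and memoizes the pair per payment_hash:
    #
    #   lnaddr, bolt11 = lnworker.get_bolt11_invoice(
    #       payment_hash=payment_hash,
    #       amount_msat=100_000_000,      # 100k sat; None for an amountless invoice
    #       message='example',
    #       expiry=3600,
    #       fallback_address=None)
    #
    # Because the routing hints are baked in at creation time, the cache is cleared
    # via clear_invoices_cache() whenever channel liquidity changes (see
    # set_invoice_status below).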

2156
    def get_payment_secret(self, payment_hash):
5✔
2157
        return sha256(sha256(self.payment_secret_key) + payment_hash)
5✔
2158

2159
    def _get_payment_key(self, payment_hash: bytes) -> bytes:
5✔
2160
        """Return payment bucket key.
2161
        We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
2162
        as it changes over a trampoline path (in the outer onion), and these paths can overlap.
2163
        """
2164
        payment_secret = self.get_payment_secret(payment_hash)
5✔
2165
        return payment_hash + payment_secret
5✔
2166
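    # Self-contained sketch of the derivation above, assuming a placeholder secret
    # key (the real payment_secret_key belongs to the wallet):
    #
    #   from hashlib import sha256 as H
    #   payment_secret_key = bytes(32)                                   # placeholder only
    #   payment_hash = H(b'some-preimage').digest()
    #   payment_secret = H(H(payment_secret_key).digest() + payment_hash).digest()
    #   payment_key = payment_hash + payment_secret                      # 64-byte bucket key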

2167
    def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
5✔
2168
        payment_preimage = os.urandom(32)
5✔
2169
        payment_hash = sha256(payment_preimage)
5✔
2170
        info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
5✔
2171
        self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
5✔
2172
        self.save_payment_info(info, write_to_disk=False)
5✔
2173
        if write_to_disk:
5✔
2174
            self.wallet.save_db()
×
2175
        return payment_hash
5✔
2176

2177
    def bundle_payments(self, hash_list):
5✔
2178
        payment_keys = [self._get_payment_key(x) for x in hash_list]
5✔
2179
        self.payment_bundles.append(payment_keys)
5✔
2180

2181
    def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
5✔
2182
        for key_list in self.payment_bundles:
5✔
2183
            if payment_key in key_list:
5✔
2184
                return key_list
5✔
2185

2186
    def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
5✔
2187
        if sha256(preimage) != payment_hash:
5✔
2188
            raise Exception("tried to save incorrect preimage for payment_hash")
×
2189
        self.preimages[payment_hash.hex()] = preimage.hex()
5✔
2190
        if write_to_disk:
5✔
2191
            self.wallet.save_db()
5✔
2192

2193
    def get_preimage(self, payment_hash: bytes) -> Optional[bytes]:
5✔
2194
        assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"
5✔
2195
        preimage_hex = self.preimages.get(payment_hash.hex())
5✔
2196
        if preimage_hex is None:
5✔
2197
            return None
5✔
2198
        preimage_bytes = bytes.fromhex(preimage_hex)
5✔
2199
        if sha256(preimage_bytes) != payment_hash:
5✔
2200
            raise Exception("found incorrect preimage for payment_hash")
×
2201
        return preimage_bytes
5✔
2202

2203
    def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
5✔
2204
        """returns None if payment_hash is a payment we are forwarding"""
2205
        key = payment_hash.hex()
5✔
2206
        with self.lock:
5✔
2207
            if key in self.payment_info:
5✔
2208
                amount_msat, direction, status = self.payment_info[key]
5✔
2209
                return PaymentInfo(payment_hash, amount_msat, direction, status)
5✔
2210

2211
    def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: int):
5✔
2212
        info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
×
2213
        self.save_payment_info(info, write_to_disk=False)
×
2214

2215
    def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
5✔
2216
        self.hold_invoice_callbacks[payment_hash] = cb
5✔
2217

2218
    def unregister_hold_invoice(self, payment_hash: bytes):
5✔
2219
        self.hold_invoice_callbacks.pop(payment_hash)
×
2220

2221
    def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
5✔
2222
        key = info.payment_hash.hex()
5✔
2223
        assert info.status in SAVED_PR_STATUS
5✔
2224
        with self.lock:
5✔
2225
            self.payment_info[key] = info.amount_msat, info.direction, info.status
5✔
2226
        if write_to_disk:
5✔
2227
            self.wallet.save_db()
5✔
2228

2229
    def check_mpp_status(
5✔
2230
            self, *,
2231
            payment_secret: bytes,
2232
            short_channel_id: ShortChannelID,
2233
            htlc: UpdateAddHtlc,
2234
            expected_msat: int,
2235
    ) -> RecvMPPResolution:
2236
        """Returns the status of the incoming htlc set the given *htlc* belongs to.
2237

2238
        ACCEPTED simply means the mpp set is complete, and we can proceed with further
2239
        checks before fulfilling (or failing) the htlcs.
2240
        In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state
2241
        for quite some time -- not in the "WAITING" state (which would refer to the mpp set
2242
        not yet being complete!).
2243
        """
2244
        payment_hash = htlc.payment_hash
5✔
2245
        payment_key = payment_hash + payment_secret
5✔
2246
        self.update_mpp_with_received_htlc(
5✔
2247
            payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
2248
        mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
5✔
2249
        # if still waiting, calc resolution now:
2250
        if mpp_resolution == RecvMPPResolution.WAITING:
5✔
2251
            bundle = self.get_payment_bundle(payment_key)
5✔
2252
            if bundle:
5✔
2253
                payment_keys = bundle
5✔
2254
            else:
2255
                payment_keys = [payment_key]
5✔
2256
            first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
5✔
2257
            if self.get_payment_status(payment_hash) == PR_PAID:
5✔
2258
                mpp_resolution = RecvMPPResolution.ACCEPTED
×
2259
            elif self.stopping_soon:
5✔
2260
                # try to time out pending HTLCs before shutting down
2261
                mpp_resolution = RecvMPPResolution.EXPIRED
5✔
2262
            elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
5✔
2263
                mpp_resolution = RecvMPPResolution.ACCEPTED
5✔
2264
            elif time.time() - first_timestamp > self.MPP_EXPIRY:
5✔
2265
                mpp_resolution = RecvMPPResolution.EXPIRED
5✔
2266
            # save resolution, if any.
2267
            if mpp_resolution != RecvMPPResolution.WAITING:
5✔
2268
                for pkey in payment_keys:
5✔
2269
                    if pkey.hex() in self.received_mpp_htlcs:
5✔
2270
                        self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
5✔
2271

2272
        return mpp_resolution
5✔
2273
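    # Sketch of the WAITING -> {ACCEPTED, EXPIRED} transition above, in the order
    # the checks are applied (check_mpp_status itself is authoritative):
    #
    #   payment already PR_PAID                      -> ACCEPTED  (late htlc of a settled set)
    #   lnworker is stopping_soon                    -> EXPIRED   (time out pending htlcs early)
    #   every payment key in the bundle is complete  -> ACCEPTED
    #   oldest htlc is older than MPP_EXPIRY         -> EXPIRED
    #   otherwise                                    -> remain WAITING
    #
    # Once a resolution other than WAITING is found, it is written back for every
    # payment key of the bundle, so bundled invoices settle or fail together.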

2274
    def update_mpp_with_received_htlc(
5✔
2275
        self,
2276
        *,
2277
        payment_key: bytes,
2278
        scid: ShortChannelID,
2279
        htlc: UpdateAddHtlc,
2280
        expected_msat: int,
2281
    ):
2282
        # add new htlc to set
2283
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2284
        if mpp_status is None:
5✔
2285
            mpp_status = ReceivedMPPStatus(
5✔
2286
                resolution=RecvMPPResolution.WAITING,
2287
                expected_msat=expected_msat,
2288
                htlc_set=set(),
2289
            )
2290
        if expected_msat != mpp_status.expected_msat:
5✔
2291
            self.logger.info(
5✔
2292
                f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
2293
            mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
5✔
2294
        key = (scid, htlc)
5✔
2295
        if key not in mpp_status.htlc_set:
5✔
2296
            mpp_status.htlc_set.add(key)  # side-effecting htlc_set
5✔
2297
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status
5✔
2298

2299
    def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
5✔
2300
        mpp_status = self.received_mpp_htlcs[payment_key.hex()]
5✔
2301
        self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
5✔
2302
        self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
5✔
2303

2304
    def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
5✔
2305
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2306
        if not mpp_status:
5✔
2307
            return False
5✔
2308
        total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
5✔
2309
        return total >= mpp_status.expected_msat
5✔
2310

2311
    def is_accepted_mpp(self, payment_hash: bytes) -> bool:
5✔
2312
        payment_key = self._get_payment_key(payment_hash)
×
2313
        status = self.received_mpp_htlcs.get(payment_key.hex())
×
2314
        return status and status.resolution == RecvMPPResolution.ACCEPTED
×
2315

2316
    def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
5✔
2317
        mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
5✔
2318
        if not mpp_status:
5✔
2319
            return int(time.time())
5✔
2320
        return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
5✔
2321

2322
    def maybe_cleanup_mpp(
5✔
2323
            self,
2324
            short_channel_id: ShortChannelID,
2325
            htlc: UpdateAddHtlc,
2326
    ) -> None:
2327

2328
        htlc_key = (short_channel_id, htlc)
5✔
2329
        for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
5✔
2330
            if htlc_key not in mpp_status.htlc_set:
5✔
2331
                continue
5✔
2332
            assert mpp_status.resolution != RecvMPPResolution.WAITING
5✔
2333
            self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
5✔
2334
            mpp_status.htlc_set.remove(htlc_key)  # side-effecting htlc_set
5✔
2335
            if len(mpp_status.htlc_set) == 0:
5✔
2336
                self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
5✔
2337
                self.received_mpp_htlcs.pop(payment_key_hex)
5✔
2338
                self.maybe_cleanup_forwarding(payment_key_hex)
5✔
2339

2340
    def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
5✔
2341
        self.active_forwardings.pop(payment_key_hex, None)
5✔
2342
        self.forwarding_failures.pop(payment_key_hex, None)
5✔
2343

2344
    def get_payment_status(self, payment_hash: bytes) -> int:
5✔
2345
        info = self.get_payment_info(payment_hash)
5✔
2346
        return info.status if info else PR_UNPAID
5✔
2347

2348
    def get_invoice_status(self, invoice: BaseInvoice) -> int:
5✔
2349
        invoice_id = invoice.rhash
5✔
2350
        status = self.get_payment_status(bfh(invoice_id))
5✔
2351
        if status == PR_UNPAID and invoice_id in self.inflight_payments:
5✔
2352
            return PR_INFLIGHT
×
2353
        # status may be PR_FAILED
2354
        if status == PR_UNPAID and invoice_id in self.logs:
5✔
2355
            status = PR_FAILED
×
2356
        return status
5✔
2357

2358
    def set_invoice_status(self, key: str, status: int) -> None:
5✔
2359
        if status == PR_INFLIGHT:
5✔
2360
            self.inflight_payments.add(key)
5✔
2361
        elif key in self.inflight_payments:
5✔
2362
            self.inflight_payments.remove(key)
5✔
2363
        if status in SAVED_PR_STATUS:
5✔
2364
            self.set_payment_status(bfh(key), status)
5✔
2365
        util.trigger_callback('invoice_status', self.wallet, key, status)
5✔
2366
        self.logger.info(f"set_invoice_status {key}: {status}")
5✔
2367
        # liquidity changed
2368
        self.clear_invoices_cache()
5✔
2369

2370
    def set_request_status(self, payment_hash: bytes, status: int) -> None:
5✔
2371
        if self.get_payment_status(payment_hash) == status:
5✔
2372
            return
5✔
2373
        self.set_payment_status(payment_hash, status)
5✔
2374
        request_id = payment_hash.hex()
5✔
2375
        req = self.wallet.get_request(request_id)
5✔
2376
        if req is None:
5✔
2377
            return
5✔
2378
        util.trigger_callback('request_status', self.wallet, request_id, status)
5✔
2379

2380
    def set_payment_status(self, payment_hash: bytes, status: int) -> None:
5✔
2381
        info = self.get_payment_info(payment_hash)
5✔
2382
        if info is None:
5✔
2383
            # if we are forwarding
2384
            return
5✔
2385
        info = info._replace(status=status)
5✔
2386
        self.save_payment_info(info)
5✔
2387

2388
    def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
5✔
2389
        """Returns whether this was a forwarded HTLC."""
2390
        for payment_key, htlcs in self.active_forwardings.items():
5✔
2391
            if htlc_key in htlcs:
5✔
2392
                return payment_key
5✔
2393

2394
    def notify_upstream_peer(self, htlc_key: str) -> None:
5✔
2395
        """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
2396
        If we find this was a forwarded HTLC, the upstream peer is notified.
2397
        """
2398
        upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None)
5✔
2399
        if not upstream_key:
5✔
2400
            return
4✔
2401
        upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
5✔
2402
        upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
5✔
2403
        upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
5✔
2404
        if upstream_peer:
5✔
2405
            upstream_peer.downstream_htlc_resolved_event.set()
5✔
2406
            upstream_peer.downstream_htlc_resolved_event.clear()
5✔
2407

2408
    def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
5✔
2409

2410
        util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
5✔
2411
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
5✔
2412
        fw_key = self.is_forwarded_htlc(htlc_key)
5✔
2413
        if fw_key:
5✔
2414
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2415
            fw_htlcs.remove(htlc_key)
5✔
2416

2417
        if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
5✔
2418
            chan.pop_onion_key(htlc_id)
5✔
2419
            payment_key = payment_hash + shi.payment_secret_orig
5✔
2420
            paysession = self._paysessions[payment_key]
5✔
2421
            q = paysession.sent_htlcs_q
5✔
2422
            htlc_log = HtlcLog(
5✔
2423
                success=True,
2424
                route=shi.route,
2425
                amount_msat=shi.amount_receiver_msat,
2426
                trampoline_fee_level=shi.trampoline_fee_level)
2427
            q.put_nowait(htlc_log)
5✔
2428
            if paysession.can_be_deleted():
5✔
2429
                self._paysessions.pop(payment_key)
5✔
2430
                paysession_active = False
5✔
2431
            else:
2432
                paysession_active = True
5✔
2433
        else:
2434
            if fw_key:
5✔
2435
                paysession_active = False
5✔
2436
            else:
2437
                key = payment_hash.hex()
5✔
2438
                self.set_invoice_status(key, PR_PAID)
5✔
2439
                util.trigger_callback('payment_succeeded', self.wallet, key)
5✔
2440

2441
        if fw_key:
5✔
2442
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2443
            if len(fw_htlcs) == 0 and not paysession_active:
5✔
2444
                self.notify_upstream_peer(htlc_key)
5✔
2445

2446

2447
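    # htlc_fulfilled (above) and htlc_failed (below) share the same three-way
    # structure for the resolved HTLC (sketch only, no new behaviour implied):
    #   1. one of our own payments  -> push an HtlcLog into the paysession queue and
    #                                  drop the paysession once can_be_deleted()
    #   2. a forwarded HTLC         -> remove it from active_forwardings and, once the
    #                                  whole set has resolved, notify the upstream peer
    #   3. unknown / stale          -> only update the invoice status and fire callbacks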
    def htlc_failed(
5✔
2448
            self,
2449
            chan: Channel,
2450
            payment_hash: bytes,
2451
            htlc_id: int,
2452
            error_bytes: Optional[bytes],
2453
            failure_message: Optional['OnionRoutingFailure']):
2454

2455
        util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
5✔
2456
        htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
5✔
2457
        fw_key = self.is_forwarded_htlc(htlc_key)
5✔
2458
        if fw_key:
5✔
2459
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2460
            fw_htlcs.remove(htlc_key)
5✔
2461

2462
        if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
5✔
2463
            onion_key = chan.pop_onion_key(htlc_id)
5✔
2464
            payment_okey = payment_hash + shi.payment_secret_orig
5✔
2465
            paysession = self._paysessions[payment_okey]
5✔
2466
            q = paysession.sent_htlcs_q
5✔
2467
            # detect if it is part of a bucket
2468
            # if yes, wait until the bucket completely failed
2469
            route = shi.route
5✔
2470
            if error_bytes:
5✔
2471
                # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
2472
                try:
5✔
2473
                    failure_message, sender_idx = decode_onion_error(
5✔
2474
                        error_bytes,
2475
                        [x.node_id for x in route],
2476
                        onion_key)
2477
                except Exception as e:
4✔
2478
                    sender_idx = None
4✔
2479
                    failure_message = OnionRoutingFailure(OnionFailureCode.INVALID_ONION_PAYLOAD, str(e).encode())
4✔
2480
            else:
2481
                # probably got "update_fail_malformed_htlc". well... who to penalise now?
2482
                assert failure_message is not None
×
2483
                sender_idx = None
×
2484
            self.logger.info(f"htlc_failed {failure_message}")
5✔
2485
            amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
5✔
2486
            if amount_receiver_msat is None:
5✔
2487
                return
5✔
2488
            if shi.trampoline_route:
5✔
2489
                route = shi.trampoline_route
5✔
2490
            htlc_log = HtlcLog(
5✔
2491
                success=False,
2492
                route=route,
2493
                amount_msat=amount_receiver_msat,
2494
                error_bytes=error_bytes,
2495
                failure_msg=failure_message,
2496
                sender_idx=sender_idx,
2497
                trampoline_fee_level=shi.trampoline_fee_level)
2498
            q.put_nowait(htlc_log)
5✔
2499
            if paysession.can_be_deleted():
5✔
2500
                self._paysessions.pop(payment_okey)
5✔
2501
                paysession_active = False
5✔
2502
            else:
2503
                paysession_active = True
5✔
2504
        else:
2505
            if fw_key:
5✔
2506
                paysession_active = False
5✔
2507
            else:
2508
                self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
5✔
2509
                key = payment_hash.hex()
5✔
2510
                self.set_invoice_status(key, PR_UNPAID)
5✔
2511
                util.trigger_callback('payment_failed', self.wallet, key, '')
5✔
2512

2513
        if fw_key:
5✔
2514
            fw_htlcs = self.active_forwardings[fw_key]
5✔
2515
            can_forward_failure = (len(fw_htlcs) == 0) and not paysession_active
5✔
2516
            if can_forward_failure:
5✔
2517
                self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
5✔
2518
                self.notify_upstream_peer(htlc_key)
5✔
2519
            else:
2520
                self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
5✔
2521

2522
    def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
5✔
2523
        """calculate routing hints (BOLT-11 'r' field)"""
2524
        routing_hints = []
5✔
2525
        if self.config.ZEROCONF_TRUSTED_NODE:
5✔
2526
            node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
×
2527
            alias_or_scid = self.get_scid_alias()
×
2528
            routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
×
2529
            # no need for more
2530
            channels = []
×
2531
        else:
2532
            if channels is None:
5✔
2533
                channels = list(self.get_channels_for_receiving(amount_msat))
5✔
2534
                random.shuffle(channels)  # let's not leak channel order
5✔
2535
            scid_to_my_channels = {
5✔
2536
                chan.short_channel_id: chan for chan in channels
2537
                if chan.short_channel_id is not None
2538
            }
2539
        for chan in channels:
5✔
2540
            alias_or_scid = chan.get_remote_scid_alias() or chan.short_channel_id
5✔
2541
            assert isinstance(alias_or_scid, bytes), alias_or_scid
5✔
2542
            channel_info = get_mychannel_info(chan.short_channel_id, scid_to_my_channels)
5✔
2543
            # note: as a fallback, if we don't have a channel update for the
2544
            # incoming direction of our private channel, we fill the invoice with garbage.
2545
            # the sender should still be able to pay us, but will incur an extra round trip
2546
            # (they will get the channel update from the onion error)
2547
            # at least, that's the theory. https://github.com/lightningnetwork/lnd/issues/2066
2548
            fee_base_msat = fee_proportional_millionths = 0
5✔
2549
            cltv_delta = 1  # lnd won't even try with zero
5✔
2550
            missing_info = True
5✔
2551
            if channel_info:
5✔
2552
                policy = get_mychannel_policy(channel_info.short_channel_id, chan.node_id, scid_to_my_channels)
5✔
2553
                if policy:
5✔
2554
                    fee_base_msat = policy.fee_base_msat
5✔
2555
                    fee_proportional_millionths = policy.fee_proportional_millionths
5✔
2556
                    cltv_delta = policy.cltv_delta
5✔
2557
                    missing_info = False
5✔
2558
            if missing_info:
5✔
2559
                self.logger.info(
×
2560
                    f"Warning. Missing channel update for our channel {chan.short_channel_id}; "
2561
                    f"filling invoice with incorrect data.")
2562
            routing_hints.append(('r', [(
5✔
2563
                chan.node_id,
2564
                alias_or_scid,
2565
                fee_base_msat,
2566
                fee_proportional_millionths,
2567
                cltv_delta)]))
2568
        return routing_hints
5✔
2569
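    # Each hint appended above is a BOLT-11 'r' tag, i.e. a one-element list of
    # 5-tuples (concrete values hypothetical):
    #
    #   ('r', [(chan.node_id,                    # our channel peer
    #           alias_or_scid,                   # remote scid alias, or real short channel id (bytes)
    #           fee_base_msat,                   # peer's policy towards us, 0 if unknown
    #           fee_proportional_millionths,     # 0 if unknown
    #           cltv_delta)])                    # 1 if unknown ("lnd won't even try with zero")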

2570
    def delete_payment_info(self, payment_hash_hex: str):
5✔
2571
        # This method is called when an invoice or request is deleted by the user.
2572
        # The GUI only lets the user delete invoices or requests that have not been paid.
2573
        # Once an invoice/request has been paid, it is part of the history,
2574
        # and get_lightning_history assumes that payment_info is there.
2575
        assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
×
2576
        with self.lock:
×
2577
            self.payment_info.pop(payment_hash_hex, None)
×
2578

2579
    def get_balance(self, frozen=False):
5✔
2580
        with self.lock:
×
2581
            return Decimal(sum(
×
2582
                chan.balance(LOCAL) if not chan.is_closed() and (chan.is_frozen_for_sending() if frozen else True) else 0
2583
                for chan in self.channels.values())) / 1000
2584

2585
    def get_channels_for_sending(self):
5✔
2586
        for c in self.channels.values():
×
2587
            if c.is_active() and not c.is_frozen_for_sending():
×
2588
                if self.channel_db or self.is_trampoline_peer(c.node_id):
×
2589
                    yield c
×
2590

2591
    def fee_estimate(self, amount_sat):
5✔
2592
        # Here we have to guess a fee, because some callers (submarine swaps)
2593
        # use this method to initiate a payment, which would otherwise fail.
2594
        fee_base_msat = 5000               # FIXME ehh.. there ought to be a better way...
×
2595
        fee_proportional_millionths = 500  # FIXME
×
2596
        # inverse of fee_for_edge_msat
2597
        amount_msat = amount_sat * 1000
×
2598
        amount_minus_fees = (amount_msat - fee_base_msat) * 1_000_000 // ( 1_000_000 + fee_proportional_millionths)
×
2599
        return Decimal(amount_msat - amount_minus_fees) / 1000
×
2600
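    # Worked example of the inversion above, using the same guessed policy
    # (5000 msat base, 500 ppm) for a 100_000 sat payment:
    #
    #   amount_msat       = 100_000 * 1000                                   # 100_000_000
    #   amount_minus_fees = (100_000_000 - 5000) * 1_000_000 // 1_000_500    # 99_945_027
    #   fee_sat           = (100_000_000 - 99_945_027) / 1000                # 54.973
    #
    # i.e. roughly 55 sat of routing fees are budgeted for a 100k sat payment.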

2601
    def num_sats_can_send(self, deltas=None) -> Decimal:
5✔
2602
        """
2603
        Without trampoline: the sum of the send capacity of all usable channels.
2604
        With trampoline: an MPP payment must go through a single trampoline node, so only the best per-node sum counts.
2605
        """
2606
        if deltas is None:
×
2607
            deltas = {}
×
2608
        def send_capacity(chan):
×
2609
            if chan in deltas:
×
2610
                delta_msat = deltas[chan] * 1000
×
2611
                if delta_msat > chan.available_to_spend(REMOTE):
×
2612
                    delta_msat = 0
×
2613
            else:
2614
                delta_msat = 0
×
2615
            return chan.available_to_spend(LOCAL) + delta_msat
×
2616
        can_send_dict = defaultdict(int)
×
2617
        with self.lock:
×
2618
            for c in self.get_channels_for_sending():
×
2619
                if not self.uses_trampoline():
×
2620
                    can_send_dict[0] += send_capacity(c)
×
2621
                else:
2622
                    can_send_dict[c.node_id] += send_capacity(c)
×
2623
        can_send = max(can_send_dict.values()) if can_send_dict else 0
×
2624
        can_send_sat = Decimal(can_send)/1000
×
2625
        can_send_sat -= self.fee_estimate(can_send_sat)
×
2626
        return max(can_send_sat, 0)
×
2627
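    # Sketch of the aggregation above, with hypothetical capacities in msat. With
    # trampoline, capacity is summed per trampoline node and only the best node
    # counts, since an MPP payment cannot span different trampolines:
    #
    #   caps = {'tramp_A': [3_000_000, 2_000_000], 'tramp_B': [4_000_000]}
    #   per_node = {n: sum(v) for n, v in caps.items()}      # A: 5_000_000, B: 4_000_000
    #   can_send_msat = max(per_node.values())               # 5_000_000, not 9_000_000
    #
    # Without trampoline everything lands in a single bucket, so the max equals the
    # plain sum over all sendable channels.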

2628
    def get_channels_for_receiving(self, amount_msat=None) -> Sequence[Channel]:
5✔
2629
        if not amount_msat:  # assume we want to recv a large amt, e.g. finding max.
5✔
2630
            amount_msat = float('inf')
×
2631
        with self.lock:
5✔
2632
            channels = list(self.channels.values())
5✔
2633
            # we exclude channels that cannot *right now* receive (e.g. peer offline)
2634
            channels = [chan for chan in channels
5✔
2635
                        if (chan.is_open() and not chan.is_frozen_for_receiving())]
2636
            # Filter out nodes that have low receive capacity compared to invoice amt.
2637
            # Even with MPP, below a certain threshold, including these channels probably
2638
            # hurts more than it helps, as they lead to many failed attempts for the sender.
2639
            channels = sorted(channels, key=lambda chan: -chan.available_to_spend(REMOTE))
5✔
2640
            selected_channels = []
5✔
2641
            running_sum = 0
5✔
2642
            cutoff_factor = 0.2  # heuristic
5✔
2643
            for chan in channels:
5✔
2644
                recv_capacity = chan.available_to_spend(REMOTE)
5✔
2645
                chan_can_handle_payment_as_single_part = recv_capacity >= amount_msat
5✔
2646
                chan_small_compared_to_running_sum = recv_capacity < cutoff_factor * running_sum
5✔
2647
                if not chan_can_handle_payment_as_single_part and chan_small_compared_to_running_sum:
5✔
2648
                    break
5✔
2649
                running_sum += recv_capacity
5✔
2650
                selected_channels.append(chan)
5✔
2651
            channels = selected_channels
5✔
2652
            del selected_channels
5✔
2653
            # cap max channels to include to keep QR code reasonably scannable
2654
            channels = channels[:10]
5✔
2655
            return channels
5✔
2656
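    # The cutoff heuristic above, traced on hypothetical receive capacities
    # (msat, already sorted descending) for an 8_000_000 msat invoice:
    #
    #   capacities = [5_000_000, 3_000_000, 2_000_000, 200_000]
    #   # 5_000_000: kept   (running_sum 0 -> 5_000_000)
    #   # 3_000_000: kept   (3_000_000 >= 0.2 * 5_000_000)
    #   # 2_000_000: kept   (2_000_000 >= 0.2 * 8_000_000)
    #   # 200_000:   break  (200_000 < 0.2 * 10_000_000 and too small to carry the
    #   #                    amount on its own)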

2657
    def num_sats_can_receive(self, deltas=None) -> Decimal:
5✔
2658
        """
2659
        We no longer assume that the sender will split an MPP payment across several of our channels,
2660
        because our per-channel liquidity is hard for them to guess.
2661
        """
2662
        if deltas is None:
×
2663
            deltas = {}
×
2664
        def recv_capacity(chan):
×
2665
            if chan in deltas:
×
2666
                delta_msat = deltas[chan] * 1000
×
2667
                if delta_msat > chan.available_to_spend(LOCAL):
×
2668
                    delta_msat = 0
×
2669
            else:
2670
                delta_msat = 0
×
2671
            return chan.available_to_spend(REMOTE) + delta_msat
×
2672
        with self.lock:
×
2673
            recv_channels = self.get_channels_for_receiving()
×
2674
            recv_chan_msats = [recv_capacity(chan) for chan in recv_channels]
×
2675
        if not recv_chan_msats:
×
2676
            return Decimal(0)
×
2677
        can_receive_msat = max(recv_chan_msats)
×
2678
        return Decimal(can_receive_msat) / 1000
×
2679

2680

2681
    def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
5✔
2682
        """
2683
        Suggest a channel and an amount to send/receive through that channel, so that we will be able to receive/send amount_sat.
2684
        This is used when suggesting a swap or a rebalance in order to receive a payment.
2685
        """
2686
        with self.lock:
×
2687
            func = self.num_sats_can_send if direction == SENT else self.num_sats_can_receive
×
2688
            suggestions = []
×
2689
            channels = self.get_channels_for_sending() if direction == SENT else self.get_channels_for_receiving()
×
2690
            for chan in channels:
×
2691
                available_sat = chan.available_to_spend(LOCAL if direction == SENT else REMOTE) // 1000
×
2692
                delta = amount_sat - available_sat
×
2693
                delta += self.fee_estimate(amount_sat)
×
2694
                # add safety margin
2695
                delta += delta // 100 + 1
×
2696
                if func(deltas={chan:delta}) >= amount_sat:
×
2697
                    suggestions.append((chan, delta))
×
2698
                elif direction==RECEIVED and func(deltas={chan:2*delta}) >= amount_sat:
×
2699
                    # the MPP receive heuristic has a 0.5 slope
2700
                    suggestions.append((chan, 2*delta))
×
2701
        if not suggestions:
×
2702
            raise NotEnoughFunds
×
2703
        return suggestions
×
2704
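    # Approximate worked example of the delta computed above (numbers hypothetical).
    # To be able to receive 100_000 sat through a channel that can currently take
    # only 60_000 sat more:
    #
    #   delta  = 100_000 - 60_000           # 40_000 sat shortfall
    #   delta += self.fee_estimate(100_000) # ~55 sat with the guessed 5000/500ppm policy
    #   delta += delta // 100 + 1           # ~1% safety margin -> roughly 40_456 sat
    #
    # For the receive direction a second attempt with 2*delta is also tried,
    # matching the 0.5 slope of the MPP receive heuristic noted above.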

2705
    def _suggest_rebalance(self, direction, amount_sat):
5✔
2706
        """
2707
        Suggest a rebalance in order to be able to send or receive amount_sat.
2708
        Returns (from_channel, to_channel, amount to shuffle)
2709
        """
2710
        try:
×
2711
            suggestions = self._suggest_channels_for_rebalance(direction, amount_sat)
×
2712
        except NotEnoughFunds:
×
2713
            return False
×
2714
        for chan2, delta in suggestions:
×
2715
            # margin for fee caused by rebalancing
2716
            delta += self.fee_estimate(amount_sat)
×
2717
            # find other channel or trampoline that can send delta
2718
            for chan1 in self.channels.values():
×
2719
                if chan1.is_frozen_for_sending() or not chan1.is_active():
×
2720
                    continue
×
2721
                if chan1 == chan2:
×
2722
                    continue
×
2723
                if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2724
                    continue
×
2725
                if direction == SENT:
×
2726
                    if chan1.can_pay(delta*1000):
×
2727
                        return (chan1, chan2, delta)
×
2728
                else:
2729
                    if chan1.can_receive(delta*1000):
×
2730
                        return (chan2, chan1, delta)
×
2731
            else:
2732
                continue
×
2733
        else:
2734
            return False
×
2735

2736
    def num_sats_can_rebalance(self, chan1, chan2):
5✔
2737
        # TODO: we should be able to spend 'max', with variable fee
2738
        n1 = chan1.available_to_spend(LOCAL)
×
2739
        n1 -= self.fee_estimate(n1)
×
2740
        n2 = chan2.available_to_spend(REMOTE)
×
2741
        amount_sat = min(n1, n2) // 1000
×
2742
        return amount_sat
×
2743

2744
    def suggest_rebalance_to_send(self, amount_sat):
5✔
2745
        return self._suggest_rebalance(SENT, amount_sat)
×
2746

2747
    def suggest_rebalance_to_receive(self, amount_sat):
5✔
2748
        return self._suggest_rebalance(RECEIVED, amount_sat)
×
2749

2750
    def suggest_swap_to_send(self, amount_sat, coins):
5✔
2751
        # fixme: if swap_amount_sat is lower than the minimum swap amount, we need to propose a higher value
2752
        assert amount_sat > self.num_sats_can_send()
×
2753
        try:
×
2754
            suggestions = self._suggest_channels_for_rebalance(SENT, amount_sat)
×
2755
        except NotEnoughFunds:
×
2756
            return
×
2757
        for chan, swap_recv_amount in suggestions:
×
2758
            # check that we can send onchain
2759
            swap_server_mining_fee = 10000 # guessing, because we have not called get_pairs yet
×
2760
            swap_funding_sat = swap_recv_amount + swap_server_mining_fee
×
2761
            swap_output = PartialTxOutput.from_address_and_value(DummyAddress.SWAP, int(swap_funding_sat))
×
2762
            if not self.wallet.can_pay_onchain([swap_output], coins=coins):
×
2763
                continue
×
2764
            return (chan, swap_recv_amount)
×
2765

2766
    def suggest_swap_to_receive(self, amount_sat):
5✔
2767
        assert amount_sat > self.num_sats_can_receive()
×
2768
        try:
×
2769
            suggestions = self._suggest_channels_for_rebalance(RECEIVED, amount_sat)
×
2770
        except NotEnoughFunds:
×
2771
            return
×
2772
        for chan, swap_recv_amount in suggestions:
×
2773
            return (chan, swap_recv_amount)
×
2774

2775
    async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msat: int):
5✔
2776
        if chan1 == chan2:
×
2777
            raise Exception('Rebalance requires two different channels')
×
2778
        if self.uses_trampoline() and chan1.node_id == chan2.node_id:
×
2779
            raise Exception('Rebalance requires channels from different trampolines')
×
2780
        payment_hash = self.create_payment_info(amount_msat=amount_msat)
×
2781
        lnaddr, invoice = self.get_bolt11_invoice(
×
2782
            payment_hash=payment_hash,
2783
            amount_msat=amount_msat,
2784
            message='rebalance',
2785
            expiry=3600,
2786
            fallback_address=None,
2787
            channels=[chan2],
2788
        )
2789
        return await self.pay_invoice(
×
2790
            invoice, channels=[chan1])
2791
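    # Rebalancing is implemented above as paying ourselves: the invoice advertises
    # only chan2 in its routing hints while the outgoing payment is restricted to
    # chan1, so liquidity moves from chan1 to chan2 (minus routing fees).
    # Hypothetical call:
    #
    #   await lnworker.rebalance_channels(chan_from, chan_to, amount_msat=50_000_000)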

2792
    def can_receive_invoice(self, invoice: BaseInvoice) -> bool:
5✔
2793
        assert invoice.is_lightning()
×
2794
        return (invoice.get_amount_sat() or 0) <= self.num_sats_can_receive()
×
2795

2796
    async def close_channel(self, chan_id):
5✔
2797
        chan = self._channels[chan_id]
×
2798
        peer = self._peers[chan.node_id]
×
2799
        return await peer.close_channel(chan_id)
×
2800

2801
    def _force_close_channel(self, chan_id: bytes) -> Transaction:
5✔
2802
        chan = self._channels[chan_id]
5✔
2803
        tx = chan.force_close_tx()
5✔
2804
        # We set the channel state to make sure we won't sign new commitment txs.
2805
        # We expect the caller to try to broadcast this tx, after which it is
2806
        # not safe to keep using the channel even if the broadcast errors (server could be lying).
2807
        # Until the tx is seen in the mempool, there will be automatic rebroadcasts.
2808
        chan.set_state(ChannelState.FORCE_CLOSING)
5✔
2809
        # Add local tx to wallet to also allow manual rebroadcasts.
2810
        try:
5✔
2811
            self.wallet.adb.add_transaction(tx)
5✔
2812
        except UnrelatedTransactionException:
×
2813
            pass  # this can happen if (almost) all of the balance goes to REMOTE
×
2814
        return tx
5✔
2815

2816
    async def force_close_channel(self, chan_id: bytes) -> str:
5✔
2817
        """Force-close the channel. Network-related exceptions are propagated to the caller.
2818
        (automatic rebroadcasts will be scheduled)
2819
        """
2820
        # note: as we are async, it can take a few event loop iterations between the caller
2821
        #       "calling us" and us getting to run, and we only set the channel state now:
2822
        tx = self._force_close_channel(chan_id)
5✔
2823
        await self.network.broadcast_transaction(tx)
5✔
2824
        return tx.txid()
5✔
2825

2826
    def schedule_force_closing(self, chan_id: bytes) -> 'asyncio.Task[bool]':
5✔
2827
        """Schedules a task to force-close the channel and returns it.
2828
        Network-related exceptions are suppressed.
2829
        (automatic rebroadcasts will be scheduled)
2830
        Note: this method is intentionally not async so that callers have a guarantee
2831
              that the channel state is set immediately.
2832
        """
2833
        tx = self._force_close_channel(chan_id)
5✔
2834
        return asyncio.create_task(self.network.try_broadcasting(tx, 'force-close'))
5✔
2835
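    # The two force-close entry points above differ mainly in error handling
    # (sketch, hypothetical usage):
    #
    #   txid = await lnworker.force_close_channel(chan_id)  # propagates broadcast errors
    #   task = lnworker.schedule_force_closing(chan_id)     # fire-and-forget; broadcast errors
    #                                                       # are suppressed, and the channel is
    #                                                       # already FORCE_CLOSING on return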

2836
    def remove_channel(self, chan_id):
5✔
2837
        chan = self.channels[chan_id]
×
2838
        assert chan.can_be_deleted()
×
2839
        with self.lock:
×
2840
            self._channels.pop(chan_id)
×
2841
            self.db.get('channels').pop(chan_id.hex())
×
2842
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
×
2843

2844
        util.trigger_callback('channels_updated', self.wallet)
×
2845
        util.trigger_callback('wallet_updated', self.wallet)
×
2846

2847
    @ignore_exceptions
5✔
2848
    @log_exceptions
5✔
2849
    async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
5✔
2850
        now = time.time()
×
2851
        peer_addresses = []
×
2852
        if self.uses_trampoline():
×
2853
            addr = trampolines_by_id().get(chan.node_id)
×
2854
            if addr:
×
2855
                peer_addresses.append(addr)
×
2856
        else:
2857
            # will try last good address first, from gossip
2858
            last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
×
2859
            if last_good_addr:
×
2860
                peer_addresses.append(last_good_addr)
×
2861
            # will try addresses for node_id from gossip
2862
            addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
×
2863
            for host, port, ts in addrs_from_gossip:
×
2864
                peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
×
2865
        # will try addresses stored in channel storage
2866
        peer_addresses += list(chan.get_peer_addresses())
×
2867
        # Done gathering addresses.
2868
        # Now select first one that has not failed recently.
2869
        for peer in peer_addresses:
×
2870
            if self._can_retry_addr(peer, urgent=True, now=now):
×
2871
                await self._add_peer(peer.host, peer.port, peer.pubkey)
×
2872
                return
×
2873

2874
    async def reestablish_peers_and_channels(self):
5✔
2875
        while True:
×
2876
            await asyncio.sleep(1)
×
2877
            if self.stopping_soon:
×
2878
                return
×
2879
            if self.config.ZEROCONF_TRUSTED_NODE:
×
2880
                peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
×
2881
                if self._can_retry_addr(peer, urgent=True):
×
2882
                    await self._add_peer(peer.host, peer.port, peer.pubkey)
×
2883
            for chan in self.channels.values():
×
2884
                # reestablish
2885
                # note: we delegate filtering out uninteresting chans to this:
2886
                if not chan.should_try_to_reestablish_peer():
×
2887
                    continue
×
2888
                peer = self._peers.get(chan.node_id, None)
×
2889
                if peer:
×
2890
                    await peer.taskgroup.spawn(peer.reestablish_channel(chan))
×
2891
                else:
2892
                    await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
×
2893

2894
    def current_target_feerate_per_kw(self) -> int:
5✔
2895
        from .simple_config import FEE_LN_ETA_TARGET, FEERATE_FALLBACK_STATIC_FEE
5✔
2896
        from .simple_config import FEERATE_PER_KW_MIN_RELAY_LIGHTNING
5✔
2897
        if constants.net is constants.BitcoinRegtest:
5✔
2898
            feerate_per_kvbyte = self.network.config.FEE_EST_STATIC_FEERATE
×
2899
        else:
2900
            feerate_per_kvbyte = self.network.config.eta_target_to_fee(FEE_LN_ETA_TARGET)
5✔
2901
            if feerate_per_kvbyte is None:
5✔
2902
                feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE
5✔
2903
        return max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
5✔
2904
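    # The `// 4` above converts sat/kvByte into sat per 1000 weight units, since
    # one virtual byte corresponds to four weight units. E.g. a hypothetical
    # estimate of 20_000 sat/kvB becomes 5_000 sat/kWU, and the result is never
    # allowed below FEERATE_PER_KW_MIN_RELAY_LIGHTNING.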

2905
    def current_low_feerate_per_kw(self) -> int:
5✔
2906
        from .simple_config import FEE_LN_LOW_ETA_TARGET
5✔
2907
        from .simple_config import FEERATE_PER_KW_MIN_RELAY_LIGHTNING
5✔
2908
        if constants.net is constants.BitcoinRegtest:
5✔
2909
            feerate_per_kvbyte = 0
×
2910
        else:
2911
            feerate_per_kvbyte = self.network.config.eta_target_to_fee(FEE_LN_LOW_ETA_TARGET) or 0
5✔
2912
        low_feerate_per_kw = max(FEERATE_PER_KW_MIN_RELAY_LIGHTNING, feerate_per_kvbyte // 4)
5✔
2913
        # make sure this is never higher than the target feerate:
2914
        low_feerate_per_kw = min(low_feerate_per_kw, self.current_target_feerate_per_kw())
5✔
2915
        return low_feerate_per_kw
5✔
2916

2917
    def create_channel_backup(self, channel_id: bytes):
5✔
2918
        chan = self._channels[channel_id]
×
2919
        # do not back up old-style channels
2920
        assert chan.is_static_remotekey_enabled()
×
2921
        peer_addresses = list(chan.get_peer_addresses())
×
2922
        peer_addr = peer_addresses[0]
×
2923
        return ImportedChannelBackupStorage(
×
2924
            node_id = chan.node_id,
2925
            privkey = self.node_keypair.privkey,
2926
            funding_txid = chan.funding_outpoint.txid,
2927
            funding_index = chan.funding_outpoint.output_index,
2928
            funding_address = chan.get_funding_address(),
2929
            host = peer_addr.host,
2930
            port = peer_addr.port,
2931
            is_initiator = chan.constraints.is_initiator,
2932
            channel_seed = chan.config[LOCAL].channel_seed,
2933
            local_delay = chan.config[LOCAL].to_self_delay,
2934
            remote_delay = chan.config[REMOTE].to_self_delay,
2935
            remote_revocation_pubkey = chan.config[REMOTE].revocation_basepoint.pubkey,
2936
            remote_payment_pubkey = chan.config[REMOTE].payment_basepoint.pubkey,
2937
            local_payment_pubkey=chan.config[LOCAL].payment_basepoint.pubkey,
2938
            multisig_funding_privkey=chan.config[LOCAL].multisig_key.privkey,
2939
        )
2940

2941
    def export_channel_backup(self, channel_id):
5✔
2942
        xpub = self.wallet.get_fingerprint()
×
2943
        backup_bytes = self.create_channel_backup(channel_id).to_bytes()
×
2944
        assert backup_bytes == ImportedChannelBackupStorage.from_bytes(backup_bytes).to_bytes(), "roundtrip failed"
×
2945
        encrypted = pw_encode_with_version_and_mac(backup_bytes, xpub)
×
2946
        assert backup_bytes == pw_decode_with_version_and_mac(encrypted, xpub), "encrypt failed"
×
2947
        return 'channel_backup:' + encrypted
×
2948

2949
    async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
5✔
2950
        if channel_id in self.channels:
×
2951
            chan = self.channels[channel_id]
×
2952
            peer = self._peers.get(chan.node_id)
×
2953
            chan.should_request_force_close = True
×
2954
            if peer:
×
2955
                peer.close_and_cleanup()  # to force a reconnect
×
2956
        elif connect_str:
×
2957
            peer = await self.add_peer(connect_str)
×
2958
            await peer.request_force_close(channel_id)
×
2959
        elif channel_id in self.channel_backups:
×
2960
            await self._request_force_close_from_backup(channel_id)
×
2961
        else:
2962
            raise Exception(f'Unknown channel {channel_id.hex()}')
×
2963

2964
    def import_channel_backup(self, data):
5✔
2965
        xpub = self.wallet.get_fingerprint()
×
2966
        cb_storage = ImportedChannelBackupStorage.from_encrypted_str(data, password=xpub)
×
2967
        channel_id = cb_storage.channel_id()
×
2968
        if channel_id.hex() in self.db.get_dict("channels"):
×
2969
            raise Exception('Channel already in wallet')
×
2970
        self.logger.info(f'importing channel backup: {channel_id.hex()}')
×
2971
        d = self.db.get_dict("imported_channel_backups")
×
2972
        d[channel_id.hex()] = cb_storage
×
2973
        with self.lock:
×
2974
            cb = ChannelBackup(cb_storage, lnworker=self)
×
2975
            self._channel_backups[channel_id] = cb
×
2976
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
2977
        self.wallet.save_db()
×
2978
        util.trigger_callback('channels_updated', self.wallet)
×
2979
        self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
2980

2981
    def has_conflicting_backup_with(self, remote_node_id: bytes):
5✔
2982
        """ Returns whether we have an active channel with this node on another device, using same local node id. """
2983
        channel_backup_peers = [
×
2984
            cb.node_id for cb in self.channel_backups.values()
2985
            if (not cb.is_closed() and cb.get_local_pubkey() == self.node_keypair.pubkey)]
2986
        return any(remote_node_id.startswith(cb_peer_nodeid) for cb_peer_nodeid in channel_backup_peers)
×
2987

2988
    def remove_channel_backup(self, channel_id):
5✔
2989
        chan = self.channel_backups[channel_id]
×
2990
        assert chan.can_be_deleted()
×
2991
        found = False
×
2992
        onchain_backups = self.db.get_dict("onchain_channel_backups")
×
2993
        imported_backups = self.db.get_dict("imported_channel_backups")
×
2994
        if channel_id.hex() in onchain_backups:
×
2995
            onchain_backups.pop(channel_id.hex())
×
2996
            found = True
×
2997
        if channel_id.hex() in imported_backups:
×
2998
            imported_backups.pop(channel_id.hex())
×
2999
            found = True
×
3000
        if not found:
×
3001
            raise Exception('Channel not found')
×
3002
        with self.lock:
×
3003
            self._channel_backups.pop(channel_id)
×
3004
        self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
×
3005
        self.wallet.save_db()
×
3006
        util.trigger_callback('channels_updated', self.wallet)
×
3007

3008
    @log_exceptions
5✔
3009
    async def _request_force_close_from_backup(self, channel_id: bytes):
5✔
3010
        cb = self.channel_backups.get(channel_id)
×
3011
        if not cb:
×
3012
            raise Exception(f'channel backup not found {self.channel_backups}')
×
3013
        cb = cb.cb # storage
×
3014
        self.logger.info(f'requesting channel force close: {channel_id.hex()}')
×
3015
        if isinstance(cb, ImportedChannelBackupStorage):
×
3016
            node_id = cb.node_id
×
3017
            privkey = cb.privkey
×
3018
            addresses = [(cb.host, cb.port, 0)]
×
3019
        else:
3020
            assert isinstance(cb, OnchainChannelBackupStorage)
×
3021
            privkey = self.node_keypair.privkey
×
3022
            for pubkey, peer_addr in trampolines_by_id().items():
×
3023
                if pubkey.startswith(cb.node_id_prefix):
×
3024
                    node_id = pubkey
×
3025
                    addresses = [(peer_addr.host, peer_addr.port, 0)]
×
3026
                    break
×
3027
            else:
3028
                # we will try with gossip (see below)
3029
                addresses = []
×
3030

3031
        async def _request_fclose(addresses):
×
3032
            for host, port, timestamp in addresses:
×
3033
                peer_addr = LNPeerAddr(host, port, node_id)
×
3034
                transport = LNTransport(privkey, peer_addr, e_proxy=ESocksProxy.from_network_settings(self.network))
×
3035
                peer = Peer(self, node_id, transport, is_channel_backup=True)
×
3036
                try:
×
3037
                    async with OldTaskGroup(wait=any) as group:
×
3038
                        await group.spawn(peer._message_loop())
×
3039
                        await group.spawn(peer.request_force_close(channel_id))
×
3040
                    return True
×
3041
                except Exception as e:
×
3042
                    self.logger.info(f'failed to connect {host} {e}')
×
3043
                    continue
×
3044
            else:
3045
                return False
×
3046
        # try first without gossip db
3047
        success = await _request_fclose(addresses)
×
3048
        if success:
×
3049
            return
×
3050
        # try with gossip db
3051
        if self.uses_trampoline():
×
3052
            raise Exception(_('Please enable gossip'))
×
3053
        node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
×
3054
        addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)
×
3055
        if not addresses_from_gossip:
×
3056
            raise Exception('Peer not found in gossip database')
×
3057
        success = await _request_fclose(addresses_from_gossip)
×
3058
        if not success:
×
3059
            raise Exception('failed to connect')
×
3060

3061
    def maybe_add_backup_from_tx(self, tx):
5✔
3062
        funding_address = None
5✔
3063
        node_id_prefix = None
5✔
3064
        for i, o in enumerate(tx.outputs()):
5✔
3065
            script_type = get_script_type_from_output_script(o.scriptpubkey)
5✔
3066
            if script_type == 'p2wsh':
5✔
3067
                funding_index = i
×
3068
                funding_address = o.address
×
3069
                for o2 in tx.outputs():
×
3070
                    if o2.scriptpubkey.startswith(bytes([opcodes.OP_RETURN])):
×
3071
                        encrypted_data = o2.scriptpubkey[2:]
×
3072
                        data = self.decrypt_cb_data(encrypted_data, funding_address)
×
3073
                        if data.startswith(CB_MAGIC_BYTES):
×
3074
                            node_id_prefix = data[len(CB_MAGIC_BYTES):]
×
3075
        if node_id_prefix is None:
5✔
3076
            return
5✔
3077
        funding_txid = tx.txid()
×
3078
        cb_storage = OnchainChannelBackupStorage(
×
3079
            node_id_prefix = node_id_prefix,
3080
            funding_txid = funding_txid,
3081
            funding_index = funding_index,
3082
            funding_address = funding_address,
3083
            is_initiator = True)
3084
        channel_id = cb_storage.channel_id().hex()
×
3085
        if channel_id in self.db.get_dict("channels"):
×
3086
            return
×
3087
        self.logger.info(f"adding backup from tx")
×
3088
        d = self.db.get_dict("onchain_channel_backups")
×
3089
        d[channel_id] = cb_storage
×
3090
        cb = ChannelBackup(cb_storage, lnworker=self)
×
3091
        self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
×
3092
        self.wallet.save_db()
×
3093
        with self.lock:
×
3094
            self._channel_backups[bfh(channel_id)] = cb
×
3095
        util.trigger_callback('channels_updated', self.wallet)
×
3096
        self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
×
3097
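    # Rough layout of the OP_RETURN output parsed above (sizes illustrative;
    # CB_MAGIC_BYTES and the prefix length are defined elsewhere):
    #
    #   scriptpubkey = OP_RETURN <push-length> <encrypted blob>
    #                                          ^ the [2:] slice skips opcode + push byte
    #   decrypt_cb_data(blob, funding_address) -> CB_MAGIC_BYTES || node_id_prefix
    #
    # Only a prefix of the peer's node id is recoverable from such an on-chain
    # backup, which is why _request_force_close_from_backup has to match it against
    # the trampoline list or the gossip db to recover the full node id.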

3098
    def save_forwarding_failure(
5✔
3099
            self, payment_key:str, *,
3100
            error_bytes: Optional[bytes] = None,
3101
            failure_message: Optional['OnionRoutingFailure'] = None):
3102
        error_hex = error_bytes.hex() if error_bytes else None
5✔
3103
        failure_hex = failure_message.to_bytes().hex() if failure_message else None
5✔
3104
        self.forwarding_failures[payment_key] = (error_hex, failure_hex)
5✔
3105

3106
    def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
5✔
3107
        error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
5✔
3108
        error_bytes = bytes.fromhex(error_hex) if error_hex else None
5✔
3109
        failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
5✔
3110
        return error_bytes, failure_message
5✔
3111
