spesmilo / electrum, build 4755190865199104 (push, via CirrusCI)

15 Dec 2025 03:38PM UTC. Coverage: 62.419% (remained the same)

Commit by SomberNight: "tests: lnpeer: simplify MockNetwork: rm dead code"
(not needed since https://github.com/spesmilo/electrum/commit/922440410)

23738 of 38030 relevant lines covered (62.42%); 0.62 hits per line

Source File: /electrum/util.py (65.19% of lines covered)
# Electrum - lightweight Bitcoin client
# Copyright (C) 2011 Thomas Voegtlin
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import concurrent.futures
from dataclasses import dataclass
import logging
import os
import sys
import re
from collections import defaultdict, OrderedDict
from concurrent.futures.process import ProcessPoolExecutor
from typing import (
    NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence, Dict, Generic, TypeVar, List, Iterable,
    Set, Awaitable
)
from types import MappingProxyType
from datetime import datetime, timezone, timedelta
import decimal
from decimal import Decimal
import threading
import hmac
import hashlib
import stat
import asyncio
import builtins
import json
import time
import ssl
import ipaddress
from ipaddress import IPv4Address, IPv6Address
import random
import secrets
import functools
from functools import partial
from abc import abstractmethod, ABC
import enum
from contextlib import nullcontext, suppress
import traceback
import inspect

import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
import aiorpcx
import certifi
import dns.asyncresolver

from .i18n import _
from .logging import get_logger, Logger

if TYPE_CHECKING:
    from .network import Network, ProxySettings
    from .interface import Interface
    from .simple_config import SimpleConfig


_logger = get_logger(__name__)


def inv_dict(d):
    return {v: k for k, v in d.items()}

def all_subclasses(cls) -> Set:
    """Return all (transitive) subclasses of cls."""
    res = set(cls.__subclasses__())
    for sub in res.copy():
        res |= all_subclasses(sub)
    return res


ca_path = certifi.where()


base_units = {'BTC':8, 'mBTC':5, 'bits':2, 'sat':0}
base_units_inverse = inv_dict(base_units)
base_units_list = ['BTC', 'mBTC', 'bits', 'sat']  # list(dict) does not guarantee order

DECIMAL_POINT_DEFAULT = 5  # mBTC


class UnknownBaseUnit(Exception): pass


def decimal_point_to_base_unit_name(dp: int) -> str:
    # e.g. 8 -> "BTC"
    try:
        return base_units_inverse[dp]
    except KeyError:
        raise UnknownBaseUnit(dp) from None


def base_unit_name_to_decimal_point(unit_name: str) -> int:
    """Returns the max number of digits allowed after the decimal point."""
    # e.g. "BTC" -> 8
    try:
        return base_units[unit_name]
    except KeyError:
        raise UnknownBaseUnit(unit_name) from None


def parse_max_spend(amt: Any) -> Optional[int]:
    """Checks whether the given amount is "spend-max"-like.
    Returns None or the positive integer weight for "max". Never raises.

    When creating invoices and on-chain txs, the user can specify to send "max".
    This is done by setting the amount to '!'. Splitting max between multiple
    tx outputs is also possible, and custom weights (positive ints) can also be used.
    For example, to send 40% of all coins to address1, and 60% to address2:
    ```
    address1, 2!
    address2, 3!
    ```
    """
    if not (isinstance(amt, str) and amt and amt[-1] == '!'):
        return None
    if amt == '!':
        return 1
    x = amt[:-1]
    try:
        x = int(x)
    except ValueError:
        return None
    if x > 0:
        return x
    return None

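# Editor's note: a minimal usage sketch for parse_max_spend() (illustration
# only, not part of the original module):
#     >>> parse_max_spend('!')    # plain "send max"
#     1
#     >>> parse_max_spend('2!')   # weighted "max", e.g. one of several outputs
#     2
#     >>> parse_max_spend('0!') is None   # weights must be positive
#     True
#     >>> parse_max_spend(100) is None    # non-strings are never "max"
#     True
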
class NotEnoughFunds(Exception):
    def __str__(self):
        return _("Insufficient funds")


class UneconomicFee(Exception):
    def __str__(self):
        return _("The fee for the transaction is higher than the funds gained from it.")


class NoDynamicFeeEstimates(Exception):
    def __str__(self):
        return _('Dynamic fee estimates not available')


class BelowDustLimit(Exception):
    pass


class InvalidPassword(Exception):
    def __init__(self, message: Optional[str] = None):
        self.message = message

    def __str__(self):
        if self.message is None:
            return _("Incorrect password")
        else:
            return str(self.message)


class AddTransactionException(Exception):
    pass


class UnrelatedTransactionException(AddTransactionException):
    def __str__(self):
        return _("Transaction is unrelated to this wallet.")


class FileImportFailed(Exception):
    def __init__(self, message=''):
        self.message = str(message)

    def __str__(self):
        return _("Failed to import from file.") + "\n" + self.message


class FileExportFailed(Exception):
    def __init__(self, message=''):
        self.message = str(message)

    def __str__(self):
        return _("Failed to export to file.") + "\n" + self.message


class WalletFileException(Exception):
    def __init__(self, message='', *, should_report_crash: bool = False):
        Exception.__init__(self, message)
        self.should_report_crash = should_report_crash


class BitcoinException(Exception): pass


class UserFacingException(Exception):
    """Exception that contains information intended to be shown to the user."""


class InvoiceError(UserFacingException): pass


class NetworkOfflineException(UserFacingException):
    """Can be raised if we are running in offline mode (--offline flag)
    and the user requests an operation that requires the network.
    """
    def __str__(self):
        return _("You are offline.")


# Throw this exception to unwind the stack like when an error occurs.
# However unlike other exceptions the user won't be informed.
class UserCancelled(Exception):
    '''An exception that is suppressed from the user'''
    pass


def to_decimal(x: Union[str, float, int, Decimal]) -> Decimal:
    # helper function mainly for float->Decimal conversion, i.e.:
    #   >>> Decimal(41754.681)
    #   Decimal('41754.680999999996856786310672760009765625')
    #   >>> Decimal("41754.681")
    #   Decimal('41754.681')
    if isinstance(x, Decimal):
        return x
    if isinstance(x, int):
        return Decimal(x)
    return Decimal(str(x))

# note: this is not a NamedTuple as then its json encoding cannot be customized
class Satoshis(object):
    __slots__ = ('value',)

    def __new__(cls, value):
        self = super(Satoshis, cls).__new__(cls)
        # note: 'value' sometimes has msat precision
        assert isinstance(value, (int, Decimal)), f"unexpected type for {value=!r}"
        self.value = value
        return self

    def __repr__(self):
        return f'Satoshis({self.value})'

    def __str__(self):
        # note: precision is truncated to satoshis here
        return format_satoshis(self.value)

    def __eq__(self, other):
        return self.value == other.value

    def __ne__(self, other):
        return not (self == other)

    def __add__(self, other):
        return Satoshis(self.value + other.value)


# note: this is not a NamedTuple as then its json encoding cannot be customized
class Fiat(object):
    __slots__ = ('value', 'ccy')

    def __new__(cls, value: Optional[Decimal], ccy: str):
        self = super(Fiat, cls).__new__(cls)
        self.ccy = ccy
        if not isinstance(value, (Decimal, type(None))):
            raise TypeError(f"value should be Decimal or None, not {type(value)}")
        self.value = value
        return self

    def __repr__(self):
        return 'Fiat(%s)' % self.__str__()

    def __str__(self):
        if self.value is None or self.value.is_nan():
            return _('No Data')
        else:
            return "{:.2f}".format(self.value)

    def to_ui_string(self):
        if self.value is None or self.value.is_nan():
            return _('No Data')
        else:
            return "{:.2f}".format(self.value) + ' ' + self.ccy

    def __eq__(self, other):
        if not isinstance(other, Fiat):
            return False
        if self.ccy != other.ccy:
            return False
        if isinstance(self.value, Decimal) and isinstance(other.value, Decimal) \
                and self.value.is_nan() and other.value.is_nan():
            return True
        return self.value == other.value

    def __ne__(self, other):
        return not (self == other)

    def __add__(self, other):
        assert self.ccy == other.ccy
        return Fiat(self.value + other.value, self.ccy)


class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        # note: this does not get called for namedtuples :(  https://bugs.python.org/issue30343
        from .transaction import Transaction, TxOutput
        if isinstance(obj, Transaction):
            return obj.serialize()
        if isinstance(obj, TxOutput):
            return obj.to_legacy_tuple()
        if isinstance(obj, Satoshis):
            return str(obj)
        if isinstance(obj, Fiat):
            return str(obj)
        if isinstance(obj, Decimal):
            return str(obj)
        if isinstance(obj, datetime):
            # note: if there is a timezone specified, this will include the offset
            return obj.isoformat(' ', timespec="minutes")
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, bytes):  # for namedtuples in lnchannel
            return obj.hex()
        if hasattr(obj, 'to_json') and callable(obj.to_json):
            return obj.to_json()
        return super(MyEncoder, self).default(obj)

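# Editor's note: illustration (not part of the original module) of how
# MyEncoder lets json.dumps() handle the custom types above:
#     >>> json.dumps({'amount': Satoshis(1000)}, cls=MyEncoder)
#     '{"amount": "0.00001"}'
#     >>> json.dumps(b'\xde\xad', cls=MyEncoder)   # bytes become hex strings
#     '"dead"'
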
class ThreadJob(Logger):
    """A job that is run periodically from a thread's main loop.  run() is
    called from that thread's context.
    """

    def __init__(self):
        Logger.__init__(self)

    def run(self):
        """Called periodically from the thread"""
        pass


class DebugMem(ThreadJob):
    '''A handy class for debugging GC memory leaks'''
    def __init__(self, classes, interval=30):
        ThreadJob.__init__(self)
        self.next_time = 0
        self.classes = classes
        self.interval = interval

    def mem_stats(self):
        import gc
        self.logger.info("Start memscan")
        gc.collect()
        objmap = defaultdict(list)
        for obj in gc.get_objects():
            for class_ in self.classes:
                try:
                    _isinstance = isinstance(obj, class_)
                except AttributeError:
                    _isinstance = False
                if _isinstance:
                    objmap[class_].append(obj)
        for class_, objs in objmap.items():
            self.logger.info(f"{class_.__name__}: {len(objs)}")
        self.logger.info("Finish memscan")

    def run(self):
        if time.time() > self.next_time:
            self.mem_stats()
            self.next_time = time.time() + self.interval


class DaemonThread(threading.Thread, Logger):
    """ daemon thread that terminates cleanly """

    def __init__(self):
        threading.Thread.__init__(self)
        Logger.__init__(self)
        self.parent_thread = threading.current_thread()
        self.running = False
        self.running_lock = threading.Lock()
        self.job_lock = threading.Lock()
        self.jobs = []
        self.stopped_event = threading.Event()        # set when fully stopped
        self.stopped_event_async = asyncio.Event()    # set when fully stopped
        self.wake_up_event = threading.Event()  # for perf optimisation of polling in run()

    def add_jobs(self, jobs):
        with self.job_lock:
            self.jobs.extend(jobs)

    def run_jobs(self):
        # Don't let a throwing job disrupt the thread, future runs of
        # itself, or other jobs.  This is useful protection against
        # malformed or malicious server responses
        with self.job_lock:
            for job in self.jobs:
                try:
                    job.run()
                except Exception as e:
                    self.logger.exception('')

    def remove_jobs(self, jobs):
        with self.job_lock:
            for job in jobs:
                self.jobs.remove(job)

    def start(self):
        with self.running_lock:
            self.running = True
        return threading.Thread.start(self)

    def is_running(self):
        with self.running_lock:
            return self.running and self.parent_thread.is_alive()

    def stop(self):
        with self.running_lock:
            self.running = False
            self.wake_up_event.set()
            self.wake_up_event.clear()

    def on_stop(self):
        if 'ANDROID_DATA' in os.environ:
            import jnius
            jnius.detach()
            self.logger.info("jnius detach")
        self.logger.info("stopped")
        self.stopped_event.set()
        loop = get_asyncio_loop()
        loop.call_soon_threadsafe(self.stopped_event_async.set)

def print_stderr(*args):
    args = [str(item) for item in args]
    sys.stderr.write(" ".join(args) + "\n")
    sys.stderr.flush()


def print_msg(*args):
    # Stringify args
    args = [str(item) for item in args]
    sys.stdout.write(" ".join(args) + "\n")
    sys.stdout.flush()


def json_encode(obj):
    try:
        s = json.dumps(obj, sort_keys=True, indent=4, cls=MyEncoder)
    except TypeError:
        s = repr(obj)
    return s


def json_decode(x):
    try:
        return json.loads(x, parse_float=Decimal)
    except Exception:
        return x


def json_normalize(x):
    # note: The return value of commands, when going through the JSON-RPC interface,
    #       is json-encoded. The encoder used there cannot handle some types, e.g. electrum.util.Satoshis.
    # note: We should not simply do "json_encode(x)" here, as then later x would get doubly json-encoded.
    # see #5868
    return json_decode(json_encode(x))


# taken from Django Source Code
def constant_time_compare(val1, val2):
    """Return True if the two strings are equal, False otherwise."""
    return hmac.compare_digest(to_bytes(val1, 'utf8'), to_bytes(val2, 'utf8'))

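# Editor's note: illustration (not part of the original module) of the
# round-trip json_normalize() performs; types like Satoshis become plain
# JSON-compatible values, so callers can re-encode the result safely:
#     >>> json_normalize({'amount': Satoshis(1000)})
#     {'amount': '0.00001'}
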
_profiler_logger = _logger.getChild('profiler')


def profiler(func=None, *, min_threshold: Union[int, float, None] = None):
    """Function decorator that logs execution time.

    min_threshold: if set, only log if time taken is higher than threshold
    """
    if func is None:  # to make "@profiler(...)" work. (in addition to bare "@profiler")
        return partial(profiler, min_threshold=min_threshold)
    t0 = None  # type: Optional[float]

    def timer_start():
        nonlocal t0
        t0 = time.time()

    def timer_done():
        t = time.time() - t0
        if min_threshold is None or t > min_threshold:
            _profiler_logger.debug(f"{func.__qualname__} {t:,.4f} sec")

    if inspect.iscoroutinefunction(func):
        async def do_profile(*args, **kw_args):
            timer_start()
            o = await func(*args, **kw_args)
            timer_done()
            return o
    else:
        def do_profile(*args, **kw_args):
            timer_start()
            o = func(*args, **kw_args)
            timer_done()
            return o
    return do_profile

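# Editor's note: a minimal usage sketch (illustration only; the decorated
# function names are hypothetical). Both the bare and the parameterized
# forms work, for sync as well as async callables:
#     @profiler
#     def load_transactions():
#         ...
#
#     @profiler(min_threshold=0.5)  # only log when the call takes > 0.5 sec
#     async def fetch_headers():
#         ...
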
class AsyncHangDetector:
    """Context manager that logs every `n` seconds if encapsulated context still has not exited."""

    def __init__(
        self,
        *,
        period_sec: int = 15,
        message: str,
        logger: logging.Logger = None,
    ):
        self.period_sec = period_sec
        self.message = message
        self.logger = logger or _logger

    async def _monitor(self):
        # note: this assumes that the event loop itself is not blocked
        t0 = time.monotonic()
        while True:
            await asyncio.sleep(self.period_sec)
            t1 = time.monotonic()
            self.logger.info(f"{self.message} (after {t1 - t0:.2f} sec)")

    async def __aenter__(self):
        self.mtask = asyncio.create_task(self._monitor())

    async def __aexit__(self, exc_type, exc, tb):
        self.mtask.cancel()

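# Editor's note: a minimal usage sketch (illustration only; `peer.handshake()`
# is a hypothetical coroutine). A log line is emitted every `period_sec`
# seconds for as long as the wrapped await has not completed:
#     async with AsyncHangDetector(message="handshake still pending"):
#         await peer.handshake()
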
def android_ext_dir():
    from android.storage import primary_external_storage_path
    return primary_external_storage_path()


def android_backup_dir():
    pkgname = get_android_package_name()
    d = os.path.join(android_ext_dir(), pkgname)
    if not os.path.exists(d):
        os.mkdir(d)
    return d


def android_data_dir():
    import jnius
    PythonActivity = jnius.autoclass('org.kivy.android.PythonActivity')
    return PythonActivity.mActivity.getFilesDir().getPath() + '/data'


def ensure_sparse_file(filename):
    # On modern Linux, no need to do anything.
    # On Windows, need to explicitly mark file.
    if os.name == "nt":
        try:
            os.system('fsutil sparse setflag "{}" 1'.format(filename))
        except Exception as e:
            _logger.info(f'error marking file {filename} as sparse: {e}')


def get_headers_dir(config):
    return config.path


def assert_datadir_available(config_path):
    path = config_path
    if os.path.exists(path):
        return
    else:
        raise FileNotFoundError(
            'Electrum datadir does not exist. Was it deleted while running?' + '\n' +
            'Should be at {}'.format(path))


def assert_file_in_datadir_available(path, config_path):
    if os.path.exists(path):
        return
    else:
        assert_datadir_available(config_path)
        raise FileNotFoundError(
            'Cannot find file but datadir is there.' + '\n' +
            'Should be at {}'.format(path))


def standardize_path(path):
    # note: os.path.realpath() is not used, as on Windows it can return non-working paths (see #8495).
    #       This means that we don't resolve symlinks!
    return os.path.normcase(
                os.path.abspath(
                    os.path.expanduser(
                        path
    )))

def get_new_wallet_name(wallet_folder: str) -> str:
    """Returns a file basename for a new wallet to be used.
    Can raise OSError.
    """
    i = 1
    while True:
        filename = "wallet_%d" % i
        if filename in os.listdir(wallet_folder):
            i += 1
        else:
            break
    return filename


def is_android_debug_apk() -> bool:
    is_android = 'ANDROID_DATA' in os.environ
    if not is_android:
        return False
    from jnius import autoclass
    pkgname = get_android_package_name()
    build_config = autoclass(f"{pkgname}.BuildConfig")
    return bool(build_config.DEBUG)


def get_android_package_name() -> str:
    is_android = 'ANDROID_DATA' in os.environ
    assert is_android
    from jnius import autoclass
    from android.config import ACTIVITY_CLASS_NAME
    activity = autoclass(ACTIVITY_CLASS_NAME).mActivity
    pkgname = str(activity.getPackageName())
    return pkgname

def assert_bytes(*args):
    """
    porting helper, assert args type
    """
    try:
        for x in args:
            assert isinstance(x, (bytes, bytearray))
    except Exception:
        print('assert bytes failed', list(map(type, args)))
        raise


def assert_str(*args):
    """
    porting helper, assert args type
    """
    for x in args:
        assert isinstance(x, str)


def to_string(x, enc) -> str:
    if isinstance(x, (bytes, bytearray)):
        return x.decode(enc)
    if isinstance(x, str):
        return x
    else:
        raise TypeError("Not a string or bytes like object")


def to_bytes(something, encoding='utf8') -> bytes:
    """
    Cast a str/bytes-like object to bytes; a bytearray is copied into immutable bytes.
    """
    if isinstance(something, bytes):
        return something
    if isinstance(something, str):
        return something.encode(encoding)
    elif isinstance(something, bytearray):
        return bytes(something)
    else:
        raise TypeError("Not a string or bytes like object")


bfh = bytes.fromhex


def xor_bytes(a: bytes, b: bytes) -> bytes:
    size = min(len(a), len(b))
    return ((int.from_bytes(a[:size], "big") ^ int.from_bytes(b[:size], "big"))
            .to_bytes(size, "big"))

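# Editor's note: small illustration (not part of the original module):
#     >>> to_bytes('abc')
#     b'abc'
#     >>> xor_bytes(b'\x0f\x0f', b'\xf0\x0f')   # truncated to the shorter input
#     b'\xff\x00'
#     >>> bfh('deadbeef')
#     b'\xde\xad\xbe\xef'
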
def user_dir():
    if "ELECTRUMDIR" in os.environ:
        return os.environ["ELECTRUMDIR"]
    elif 'ANDROID_DATA' in os.environ:
        return android_data_dir()
    elif os.name == 'posix':
        return os.path.join(os.environ["HOME"], ".electrum")
    elif "APPDATA" in os.environ:
        return os.path.join(os.environ["APPDATA"], "Electrum")
    elif "LOCALAPPDATA" in os.environ:
        return os.path.join(os.environ["LOCALAPPDATA"], "Electrum")
    else:
        #raise Exception("No home directory found in environment variables.")
        return


def resource_path(*parts):
    return os.path.join(pkg_dir, *parts)


# absolute path to python package folder of electrum ("lib")
pkg_dir = os.path.split(os.path.realpath(__file__))[0]


def is_valid_email(s):
    regexp = r"[^@]+@[^@]+\.[^@]+"
    return re.match(regexp, s) is not None


def is_valid_websocket_url(url: str) -> bool:
    """
    uses this django url validation regex:
    https://github.com/django/django/blob/2c6906a0c4673a7685817156576724aba13ad893/django/core/validators.py#L45C1-L52C43
    Note: this is not perfect, urls and their parsing can get very complex (see recent django code).
    however it's sufficient for catching weird user input in the gui dialog
    """
    # stores the compiled regex in the function object itself to avoid recompiling it every call
    if not hasattr(is_valid_websocket_url, "regex"):
        is_valid_websocket_url.regex = re.compile(
            r'^(?:ws|wss)://'  # ws:// or wss://
            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # domain...
            r'localhost|'  # localhost...
            r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|'  # ...or ipv4
            r'\[?[A-F0-9]*:[A-F0-9:]+\]?)'  # ...or ipv6
            r'(?::\d+)?'  # optional port
            r'(?:/?|[/?]\S+)$', re.IGNORECASE)
    try:
        return re.match(is_valid_websocket_url.regex, url) is not None
    except Exception:
        return False

def is_hash256_str(text: Any) -> bool:
    if not isinstance(text, str): return False
    if len(text) != 64: return False
    return is_hex_str(text)


def is_hex_str(text: Any) -> bool:
    if not isinstance(text, str): return False
    try:
        b = bytes.fromhex(text)
    except Exception:
        return False
    # forbid whitespaces in text:
    if len(text) != 2 * len(b):
        return False
    return True


def is_integer(val: Any) -> bool:
    return isinstance(val, int)


def is_non_negative_integer(val: Any) -> bool:
    if is_integer(val):
        return val >= 0
    return False


def is_int_or_float(val: Any) -> bool:
    return isinstance(val, (int, float))


def is_non_negative_int_or_float(val: Any) -> bool:
    if is_int_or_float(val):
        return val >= 0
    return False


def chunks(items, size: int):
    """Break up items, a sequence, into chunks of length size."""
    if size < 1:
        raise ValueError(f"size must be positive, not {repr(size)}")
    for i in range(0, len(items), size):
        yield items[i: i + size]

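# Editor's note: illustration (not part of the original module). Note that
# chunks() needs a sequence (it uses len() and slicing), not an arbitrary
# iterable:
#     >>> list(chunks([1, 2, 3, 4, 5], 2))
#     [[1, 2], [3, 4], [5]]
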
def format_satoshis_plain(
        x: Union[int, float, Decimal, str],  # amount in satoshis
        *,
        decimal_point: int = 8,  # how much to shift decimal point to left (default: sat->BTC)
        is_max_allowed: bool = True,
) -> str:
    """Display a satoshi amount scaled.  Always uses a '.' as a decimal
    point and has no thousands separator"""
    if is_max_allowed and parse_max_spend(x):
        return f'max({x})'
    assert isinstance(x, (int, float, Decimal)), f"{x!r} should be a number"
    # TODO(ghost43) just hard-fail if x is a float. do we even use floats for money anywhere?
    x = to_decimal(x)
    scale_factor = pow(10, decimal_point)
    return "{:.8f}".format(x / scale_factor).rstrip('0').rstrip('.')


# Check that Decimal precision is sufficient.
# We need at the very least ~20, as we deal with msat amounts, and
# log10(21_000_000 * 10**8 * 1000) ~= 18.3
# decimal.DefaultContext.prec == 28 by default, but it is mutable.
# We enforce that we have at least that available.
assert decimal.getcontext().prec >= 28, f"PyDecimal precision too low: {decimal.getcontext().prec}"

# DECIMAL_POINT = locale.localeconv()['decimal_point']  # type: str
DECIMAL_POINT = "."
THOUSANDS_SEP = " "
assert len(DECIMAL_POINT) == 1, f"DECIMAL_POINT has unexpected len. {DECIMAL_POINT!r}"
assert len(THOUSANDS_SEP) == 1, f"THOUSANDS_SEP has unexpected len. {THOUSANDS_SEP!r}"


def format_satoshis(
        x: Union[int, float, Decimal, str, None],  # amount in satoshis
        *,
        num_zeros: int = 0,
        decimal_point: int = 8,  # how much to shift decimal point to left (default: sat->BTC)
        precision: int = 0,  # extra digits after satoshi precision
        is_diff: bool = False,  # if True, enforce a leading sign (+/-)
        whitespaces: bool = False,  # if True, add whitespaces, to align numbers in a column
        add_thousands_sep: bool = False,  # if True, add whitespaces, for better readability of the numbers
) -> str:
    if x is None:
        return 'unknown'
    if parse_max_spend(x):
        return f'max({x})'
    assert isinstance(x, (int, float, Decimal)), f"{x!r} should be a number"
    # TODO(ghost43) just hard-fail if x is a float. do we even use floats for money anywhere?
    x = to_decimal(x)
    # lose redundant precision
    x = x.quantize(Decimal(10) ** (-precision))
    # format string
    overall_precision = decimal_point + precision  # max digits after final decimal point
    decimal_format = "." + str(overall_precision) if overall_precision > 0 else ""
    if is_diff:
        decimal_format = '+' + decimal_format
    # initial result
    scale_factor = pow(10, decimal_point)
    result = ("{:" + decimal_format + "f}").format(x / scale_factor)
    if "." not in result: result += "."
    result = result.rstrip('0')
    # add extra decimal places (zeros)
    integer_part, fract_part = result.split(".")
    if len(fract_part) < num_zeros:
        fract_part += "0" * (num_zeros - len(fract_part))
    # add whitespaces as thousands' separator for better readability of numbers
    if add_thousands_sep:
        sign = integer_part[0] if integer_part[0] in ("+", "-") else ""
        if sign == "-":
            integer_part = integer_part[1:]
        integer_part = "{:,}".format(int(integer_part)).replace(',', THOUSANDS_SEP)
        integer_part = sign + integer_part
        fract_part = THOUSANDS_SEP.join(fract_part[i:i+3] for i in range(0, len(fract_part), 3))
    result = integer_part + DECIMAL_POINT + fract_part
    # add leading/trailing whitespaces so that numbers can be aligned in a column
    if whitespaces:
        target_fract_len = overall_precision
        target_integer_len = 14 - decimal_point  # should be enough for up to unsigned 999999 BTC
        if add_thousands_sep:
            target_fract_len += max(0, (target_fract_len - 1) // 3)
            target_integer_len += max(0, (target_integer_len - 1) // 3)
        # add trailing whitespaces
        result += " " * (target_fract_len - len(fract_part))
        # add leading whitespaces
        target_total_len = target_integer_len + 1 + target_fract_len
        result = " " * (target_total_len - len(result)) + result
    return result

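# Editor's note: a few illustrative calls (not part of the original module):
#     >>> format_satoshis(1234567890)
#     '12.3456789'
#     >>> format_satoshis(1234567890, add_thousands_sep=True)
#     '12.345 678 9'
#     >>> format_satoshis(-50000, is_diff=True)
#     '-0.0005'
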
FEERATE_PRECISION = 1  # num fractional decimal places for sat/byte fee rates
_feerate_quanta = Decimal(10) ** (-FEERATE_PRECISION)
UI_UNIT_NAME_FEERATE_SAT_PER_VBYTE = "sat/vbyte"
UI_UNIT_NAME_FEERATE_SAT_PER_VB = "sat/vB"
UI_UNIT_NAME_TXSIZE_VBYTES = "vbytes"
UI_UNIT_NAME_MEMPOOL_MB = "vMB"


def format_fee_satoshis(fee, *, num_zeros=0, precision=None):
    if precision is None:
        precision = FEERATE_PRECISION
    num_zeros = min(num_zeros, FEERATE_PRECISION)  # no more zeroes than available prec
    return format_satoshis(fee, num_zeros=num_zeros, decimal_point=0, precision=precision)


def quantize_feerate(fee) -> Union[None, Decimal, int]:
    """Strip sat/byte fee rate of excess precision."""
    if fee is None:
        return None
    return Decimal(fee).quantize(_feerate_quanta, rounding=decimal.ROUND_HALF_DOWN)


DEFAULT_TIMEZONE = None  # type: timezone | None  # None means local OS timezone


def timestamp_to_datetime(timestamp: Union[int, float, None], *, utc: bool = False) -> Optional[datetime]:
    if timestamp is None:
        return None
    tz = DEFAULT_TIMEZONE
    if utc:
        tz = timezone.utc
    return datetime.fromtimestamp(timestamp, tz=tz)


def format_time(timestamp: Union[int, float, None]) -> str:
    date = timestamp_to_datetime(timestamp)
    return date.isoformat(' ', timespec="minutes") if date else _("Unknown")

def age(
    from_date: Union[int, float, None],  # POSIX timestamp
    *,
    since_date: datetime = None,
    target_tz=None,
    include_seconds: bool = False,
) -> str:
    """Takes a timestamp and returns a string with the approximation of the age"""
    if from_date is None:
        return _("Unknown")
    from_date = datetime.fromtimestamp(from_date)
    if since_date is None:
        since_date = datetime.now(target_tz)
    distance_in_time = from_date - since_date
    is_in_past = from_date < since_date
    s = delta_time_str(distance_in_time, include_seconds=include_seconds)
    return _("{} ago").format(s) if is_in_past else _("in {}").format(s)


def delta_time_str(distance_in_time: timedelta, *, include_seconds: bool = False) -> str:
    distance_in_seconds = int(round(abs(distance_in_time.days * 86400 + distance_in_time.seconds)))
    distance_in_minutes = int(round(distance_in_seconds / 60))
    if distance_in_minutes == 0:
        if include_seconds:
            return _("{} seconds").format(distance_in_seconds)
        else:
            return _("less than a minute")
    elif distance_in_minutes < 45:
        return _("about {} minutes").format(distance_in_minutes)
    elif distance_in_minutes < 90:
        return _("about 1 hour")
    elif distance_in_minutes < 1440:
        return _("about {} hours").format(round(distance_in_minutes / 60.0))
    elif distance_in_minutes < 2880:
        return _("about 1 day")
    elif distance_in_minutes < 43220:
        return _("about {} days").format(round(distance_in_minutes / 1440))
    elif distance_in_minutes < 86400:
        return _("about 1 month")
    elif distance_in_minutes < 525600:
        return _("about {} months").format(round(distance_in_minutes / 43200))
    elif distance_in_minutes < 1051200:
        return _("about 1 year")
    else:
        return _("over {} years").format(round(distance_in_minutes / 525600))

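# Editor's note: illustration (not part of the original module), pinning
# since_date so the output is deterministic:
#     >>> now = datetime(2025, 1, 1, 12, 0)
#     >>> age(now.timestamp() - 3600, since_date=now)
#     'about 1 hour ago'
#     >>> age(now.timestamp() + 120, since_date=now)
#     'in about 2 minutes'
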
mainnet_block_explorers = {
    '3xpl.com': ('https://3xpl.com/bitcoin/',
                        {'tx': 'transaction/', 'addr': 'address/'}),
    'Bitflyer.jp': ('https://chainflyer.bitflyer.jp/',
                        {'tx': 'Transaction/', 'addr': 'Address/'}),
    'Blockchain.info': ('https://blockchain.com/btc/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'Blockstream.info': ('https://blockstream.info/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'Bitaps.com': ('https://btc.bitaps.com/',
                        {'tx': '', 'addr': ''}),
    'BTC.com': ('https://btc.com/',
                        {'tx': '', 'addr': ''}),
    'Chain.so': ('https://www.chain.so/',
                        {'tx': 'tx/BTC/', 'addr': 'address/BTC/'}),
    'Insight.is': ('https://insight.bitpay.com/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'BlockCypher.com': ('https://live.blockcypher.com/btc/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'Blockchair.com': ('https://blockchair.com/bitcoin/',
                        {'tx': 'transaction/', 'addr': 'address/'}),
    'blockonomics.co': ('https://www.blockonomics.co/',
                        {'tx': 'api/tx?txid=', 'addr': '#/search?q='}),
    'mempool.space': ('https://mempool.space/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'mempool.emzy.de': ('https://mempool.emzy.de/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'OXT.me': ('https://oxt.me/',
                        {'tx': 'transaction/', 'addr': 'address/'}),
    'mynode.local': ('http://mynode.local:3002/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'system default': ('blockchain:/',
                        {'tx': 'tx/', 'addr': 'address/'}),
}

testnet_block_explorers = {
    'Bitaps.com': ('https://tbtc.bitaps.com/',
                       {'tx': '', 'addr': ''}),
    'BlockCypher.com': ('https://live.blockcypher.com/btc-testnet/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'Blockchain.info': ('https://www.blockchain.com/btc-testnet/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'Blockstream.info': ('https://blockstream.info/testnet/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'mempool.space': ('https://mempool.space/testnet/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'smartbit.com.au': ('https://testnet.smartbit.com.au/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'system default': ('blockchain://000000000933ea01ad0ee984209779baaec3ced90fa3f408719526f8d77f4943/',
                       {'tx': 'tx/', 'addr': 'address/'}),
}

testnet4_block_explorers = {
    'mempool.space': ('https://mempool.space/testnet4/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'wakiyamap.dev': ('https://testnet4-explorer.wakiyamap.dev/',
                       {'tx': 'tx/', 'addr': 'address/'}),
}

signet_block_explorers = {
    'bc-2.jp': ('https://explorer.bc-2.jp/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'mempool.space': ('https://mempool.space/signet/',
                        {'tx': 'tx/', 'addr': 'address/'}),
    'bitcoinexplorer.org': ('https://signet.bitcoinexplorer.org/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'wakiyamap.dev': ('https://signet-explorer.wakiyamap.dev/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'ex.signet.bublina.eu.org': ('https://ex.signet.bublina.eu.org/',
                       {'tx': 'tx/', 'addr': 'address/'}),
    'system default': ('blockchain:/',
                       {'tx': 'tx/', 'addr': 'address/'}),
}

_block_explorer_default_api_loc = {'tx': 'tx/', 'addr': 'address/'}


def block_explorer_info():
    from . import constants
    if constants.net.NET_NAME == "testnet":
        return testnet_block_explorers
    elif constants.net.NET_NAME == "testnet4":
        return testnet4_block_explorers
    elif constants.net.NET_NAME == "signet":
        return signet_block_explorers
    return mainnet_block_explorers


def block_explorer(config: 'SimpleConfig') -> Optional[str]:
    """Returns name of selected block explorer,
    or None if a custom one (not among hardcoded ones) is configured.
    """
    if config.BLOCK_EXPLORER_CUSTOM is not None:
        return None
    be_key = config.BLOCK_EXPLORER
    be_tuple = block_explorer_info().get(be_key)
    if be_tuple is None:
        be_key = config.cv.BLOCK_EXPLORER.get_default_value()
    assert isinstance(be_key, str), f"{be_key!r} should be str"
    return be_key


def block_explorer_tuple(config: 'SimpleConfig') -> Optional[Tuple[str, dict]]:
    custom_be = config.BLOCK_EXPLORER_CUSTOM
    if custom_be:
        if isinstance(custom_be, str):
            return custom_be, _block_explorer_default_api_loc
        if isinstance(custom_be, (tuple, list)) and len(custom_be) == 2:
            return tuple(custom_be)
        _logger.warning(f"not using {config.cv.BLOCK_EXPLORER_CUSTOM.key()!r} from config. "
                        f"expected a str or a pair but got {custom_be!r}")
        return None
    else:
        # using one of the hardcoded block explorers
        return block_explorer_info().get(block_explorer(config))


def block_explorer_URL(config: 'SimpleConfig', kind: str, item: str) -> Optional[str]:
    be_tuple = block_explorer_tuple(config)
    if not be_tuple:
        return
    explorer_url, explorer_dict = be_tuple
    kind_str = explorer_dict.get(kind)
    if kind_str is None:
        return
    if explorer_url[-1] != "/":
        explorer_url += "/"
    url_parts = [explorer_url, kind_str, item]
    return ''.join(url_parts)

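# Editor's note: a usage sketch (illustration only; assumes a SimpleConfig
# whose selected explorer is 'mempool.space'). The URL is assembled as
# base URL + kind prefix + item:
#     block_explorer_URL(config, 'tx', '<txid>')
#     # -> 'https://mempool.space/tx/<txid>'
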
# Python bug (http://bugs.python.org/issue1927) causes raw_input
# to be redirected improperly between stdin/stderr on Unix systems
#TODO: py3
def raw_input(prompt=None):
    if prompt:
        sys.stdout.write(prompt)
    return builtin_raw_input()


builtin_raw_input = builtins.input
builtins.input = raw_input


def parse_json(message):
    # TODO: check \r\n pattern
    n = message.find(b'\n')
    if n == -1:
        return None, message
    try:
        j = json.loads(message[0:n].decode('utf8'))
    except Exception:
        j = None
    return j, message[n+1:]


def setup_thread_excepthook():
    """
    Workaround for `sys.excepthook` thread bug from:
    http://bugs.python.org/issue1230540

    Call once from the main thread before creating any threads.
    """

    init_original = threading.Thread.__init__

    def init(self, *args, **kwargs):

        init_original(self, *args, **kwargs)
        run_original = self.run

        def run_with_except_hook(*args2, **kwargs2):
            try:
                run_original(*args2, **kwargs2)
            except Exception:
                sys.excepthook(*sys.exc_info())

        self.run = run_with_except_hook

    threading.Thread.__init__ = init


def send_exception_to_crash_reporter(e: BaseException):
    from .base_crash_reporter import send_exception_to_crash_reporter
    send_exception_to_crash_reporter(e)


def versiontuple(v):
    return tuple(map(int, (v.split("."))))

def read_json_file(path):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.loads(f.read())
    except json.JSONDecodeError:
        _logger.exception('')
        raise FileImportFailed(_("Invalid JSON code."))
    except BaseException as e:
        _logger.exception('')
        raise FileImportFailed(e)
    return data


def write_json_file(path, data):
    try:
        with open(path, 'w+', encoding='utf-8') as f:
            json.dump(data, f, indent=4, sort_keys=True, cls=MyEncoder)
    except (IOError, os.error) as e:
        _logger.exception('')
        raise FileExportFailed(e)


def os_chmod(path, mode):
    """os.chmod aware of tmpfs"""
    try:
        os.chmod(path, mode)
    except OSError as e:
        xdg_runtime_dir = os.environ.get("XDG_RUNTIME_DIR", None)
        if xdg_runtime_dir and is_subpath(path, xdg_runtime_dir):
            _logger.info(f"Tried to chmod in tmpfs. Skipping... {e!r}")
        else:
            raise


def make_dir(path, *, allow_symlink=True):
    """Makes directory if it does not yet exist.
    Also sets sane 0700 permissions on the dir.
    """
    if not os.path.exists(path):
        if not allow_symlink and os.path.islink(path):
            raise Exception('Dangling link: ' + path)
        try:
            os.mkdir(path)
        except FileExistsError:
            # this can happen in a multiprocess race, e.g. when an electrum daemon
            # and an electrum cli command are launched in rapid fire
            pass
        os_chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        assert os.path.exists(path)


def is_subpath(long_path: str, short_path: str) -> bool:
    """Returns whether long_path is a sub-path of short_path."""
    try:
        common = os.path.commonpath([long_path, short_path])
    except ValueError:
        return False
    short_path = standardize_path(short_path)
    common = standardize_path(common)
    return short_path == common

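# Editor's note: illustration (not part of the original module; POSIX paths
# assumed). Paths are normalized before comparison:
#     >>> is_subpath('/run/user/1000/electrum/x', '/run/user/1000')
#     True
#     >>> is_subpath('/etc/passwd', '/run/user/1000')
#     False
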
def log_exceptions(func):
    """Decorator to log AND re-raise exceptions."""
    assert inspect.iscoroutinefunction(func), 'func needs to be a coroutine'

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        self = args[0] if len(args) > 0 else None
        try:
            return await func(*args, **kwargs)
        except asyncio.CancelledError as e:
            raise
        except BaseException as e:
            mylogger = self.logger if hasattr(self, 'logger') else _logger
            try:
                mylogger.exception(f"Exception in {func.__name__}: {repr(e)}")
            except BaseException as e2:
                print(f"logging exception raised: {repr(e2)}... orig exc: {repr(e)} in {func.__name__}")
            raise
    return wrapper


def ignore_exceptions(func):
    """Decorator to silently swallow all exceptions."""
    assert inspect.iscoroutinefunction(func), 'func needs to be a coroutine'

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            pass
    return wrapper


def with_lock(func):
    """Decorator to enforce a lock on a function call."""
    @functools.wraps(func)
    def func_wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)
    return func_wrapper

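# Editor's note: a minimal usage sketch for with_lock (illustration only;
# the class is hypothetical). The decorated method's instance must expose
# a `self.lock` attribute:
#     class Ledger:
#         def __init__(self):
#             self.lock = threading.RLock()
#             self.entries = []
#
#         @with_lock
#         def append(self, entry):
#             self.entries.append(entry)
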
@dataclass(frozen=True, kw_only=True)
class TxMinedInfo:
    _height: int                       # height of block that mined tx
    conf: Optional[int] = None         # number of confirmations, SPV verified. >=0, or None (None means unknown)
    timestamp: Optional[int] = None    # timestamp of block that mined tx
    txpos: Optional[int] = None        # position of tx in serialized block
    header_hash: Optional[str] = None  # hash of block that mined tx
    wanted_height: Optional[int] = None  # in case of timelock, min abs block height

    def height(self) -> int:
        """Treat unverified heights as unconfirmed."""
        h = self._height
        if h > 0:
            if self.conf is not None and self.conf >= 1:
                return h
            return 0  # treat it as unconfirmed until SPV-ed
        else:  # h <= 0
            return h

    def short_id(self) -> Optional[str]:
        if self.txpos is not None and self.txpos >= 0:
            assert self.height() > 0
            return f"{self.height()}x{self.txpos}"
        return None

    def is_local_like(self) -> bool:
        """Returns whether the tx is local-like (LOCAL/FUTURE)."""
        from .address_synchronizer import TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT
        if self.height() > 0:
            return False
        if self.height() in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT):
            return False
        return True


class ShortID(bytes):

    def __repr__(self):
        return f"<ShortID: {format_short_id(self)}>"

    def __str__(self):
        return format_short_id(self)

    @classmethod
    def from_components(cls, block_height: int, tx_pos_in_block: int, output_index: int) -> 'ShortID':
        bh = block_height.to_bytes(3, byteorder='big')
        tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
        oi = output_index.to_bytes(2, byteorder='big')
        return ShortID(bh + tpos + oi)

    @classmethod
    def from_str(cls, scid: str) -> 'ShortID':
        """Parses a formatted scid str, e.g. '643920x356x0'."""
        components = scid.split("x")
        if len(components) != 3:
            raise ValueError(f"failed to parse ShortID: {scid!r}")
        try:
            components = [int(x) for x in components]
        except ValueError:
            raise ValueError(f"failed to parse ShortID: {scid!r}") from None
        return ShortID.from_components(*components)

    @classmethod
    def normalize(cls, data: Union[None, str, bytes, 'ShortID']) -> Optional['ShortID']:
        if isinstance(data, ShortID) or data is None:
            return data
        if isinstance(data, str):
            assert len(data) == 16
            return ShortID.fromhex(data)
        if isinstance(data, (bytes, bytearray)):
            assert len(data) == 8
            return ShortID(data)

    @property
    def block_height(self) -> int:
        return int.from_bytes(self[:3], byteorder='big')

    @property
    def txpos(self) -> int:
        return int.from_bytes(self[3:6], byteorder='big')

    @property
    def output_index(self) -> int:
        return int.from_bytes(self[6:8], byteorder='big')


def format_short_id(short_channel_id: Optional[bytes]):
    if not short_channel_id:
        return _('Not yet available')
    return str(int.from_bytes(short_channel_id[:3], 'big')) \
        + 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \
        + 'x' + str(int.from_bytes(short_channel_id[6:], 'big'))

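# Editor's note: illustration (not part of the original module). A ShortID
# packs (block_height, txpos, output_index) into 3+3+2 big-endian bytes:
#     >>> scid = ShortID.from_str('643920x356x0')
#     >>> scid.block_height, scid.txpos, scid.output_index
#     (643920, 356, 0)
#     >>> str(scid)
#     '643920x356x0'
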
def make_aiohttp_proxy_connector(proxy: 'ProxySettings', ssl_context: Optional[ssl.SSLContext] = None) -> ProxyConnector:
    return ProxyConnector(
        proxy_type=ProxyType.SOCKS5 if proxy.mode == 'socks5' else ProxyType.SOCKS4,
        host=proxy.host,
        port=int(proxy.port),
        username=proxy.user,
        password=proxy.password,
        rdns=True,  # needed to prevent DNS leaks over proxy
        ssl=ssl_context,
    )


def make_aiohttp_session(proxy: Optional['ProxySettings'], headers=None, timeout=None):
    if headers is None:
        headers = {'User-Agent': 'Electrum'}
    if timeout is None:
        # The default timeout is high intentionally.
        # DNS on some systems can be really slow, see e.g. #5337
        timeout = aiohttp.ClientTimeout(total=45)
    elif isinstance(timeout, (int, float)):
        timeout = aiohttp.ClientTimeout(total=timeout)
    ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)

    if proxy and proxy.enabled:
        connector = make_aiohttp_proxy_connector(proxy, ssl_context)
    else:
        connector = aiohttp.TCPConnector(ssl=ssl_context)

    return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector)

1393
class OldTaskGroup(aiorpcx.TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20.
    That is, when using TaskGroup as a context manager, if any task encounters an exception,
    we would like that exception to be re-raised (propagated out). For the wait=all case,
    the OldTaskGroup class is emulating the following code-snippet:
    ```
    async with TaskGroup() as group:
        await group.spawn(task1())
        await group.spawn(task2())

        async for task in group:
            if not task.cancelled():
                task.result()
    ```
    So instead of the above, one can just write:
    ```
    async with OldTaskGroup() as group:
        await group.spawn(task1())
        await group.spawn(task2())
    ```
    # TODO see if we can migrate to asyncio.timeout, introduced in python 3.11, and use stdlib instead of aiorpcx.curio...
    """
    async def join(self):
        if self._wait is all:
            exc = False
            try:
                async for task in self:
                    if not task.cancelled():
                        task.result()
            except BaseException:  # including asyncio.CancelledError
                exc = True
                raise
            finally:
                if exc:
                    await self.cancel_remaining()
                await super().join()
        else:
            await super().join()
            if self.completed:
                self.completed.result()


# We monkey-patch aiorpcx TimeoutAfter (used by the timeout_after and ignore_after APIs),
# to fix a timing issue present in asyncio as a whole re timing out tasks.
# To see the issue we are trying to fix, consider example:
#     async def outer_task():
#         async with timeout_after(0.1):
#             await inner_task()
# When the 0.1 sec timeout expires, inner_task will get cancelled by timeout_after (=internal cancellation).
# If around the same time (in terms of event loop iterations) another coroutine
# cancels outer_task (=external cancellation), there will be a race.
# Both cancellations work by propagating a CancelledError out to timeout_after, which then
# needs to decide (in TimeoutAfter.__aexit__) whether it's due to an internal or external cancellation.
# AFAICT asyncio provides no reliable way of distinguishing between the two.
# This patch tries to always give priority to external cancellations.
# see https://github.com/kyuupichan/aiorpcX/issues/44
# see https://github.com/aio-libs/async-timeout/issues/229
# see https://bugs.python.org/issue42130 and https://bugs.python.org/issue45098
# TODO see if we can migrate to asyncio.timeout, introduced in python 3.11, and use stdlib instead of aiorpcx.curio...
def _aiorpcx_monkeypatched_set_new_deadline(task, deadline):
    def timeout_task():
        task._orig_cancel()
        task._timed_out = None if getattr(task, "_externally_cancelled", False) else deadline

    def mycancel(*args, **kwargs):
        task._orig_cancel(*args, **kwargs)
        task._externally_cancelled = True
        task._timed_out = None

    if not hasattr(task, "_orig_cancel"):
        task._orig_cancel = task.cancel
        task.cancel = mycancel
    task._deadline_handle = task._loop.call_at(deadline, timeout_task)


def _aiorpcx_monkeypatched_set_task_deadline(task, deadline):
    ret = _aiorpcx_orig_set_task_deadline(task, deadline)
    task._externally_cancelled = None
    return ret


def _aiorpcx_monkeypatched_unset_task_deadline(task):
    if hasattr(task, "_orig_cancel"):
        task.cancel = task._orig_cancel
        del task._orig_cancel
    return _aiorpcx_orig_unset_task_deadline(task)


_aiorpcx_orig_set_task_deadline    = aiorpcx.curio._set_task_deadline
_aiorpcx_orig_unset_task_deadline  = aiorpcx.curio._unset_task_deadline

aiorpcx.curio._set_new_deadline    = _aiorpcx_monkeypatched_set_new_deadline
aiorpcx.curio._set_task_deadline   = _aiorpcx_monkeypatched_set_task_deadline
aiorpcx.curio._unset_task_deadline = _aiorpcx_monkeypatched_unset_task_deadline


async def wait_for2(fut: Awaitable, timeout: Union[int, float, None]):
    """Replacement for asyncio.wait_for,
    due to bugs: https://bugs.python.org/issue42130 and https://github.com/python/cpython/issues/86296 ,
    which are only fixed in python 3.12+.
    """
    if sys.version_info[:3] >= (3, 12):
        return await asyncio.wait_for(fut, timeout)
    else:
        async with async_timeout(timeout):
            return await asyncio.ensure_future(fut, loop=get_running_loop())


if hasattr(asyncio, 'timeout'):  # python 3.11+
    async_timeout = asyncio.timeout
else:
    class TimeoutAfterAsynciolike(aiorpcx.curio.TimeoutAfter):
        async def __aexit__(self, exc_type, exc_value, tb):
            try:
                await super().__aexit__(exc_type, exc_value, tb)
            except (aiorpcx.TaskTimeout, aiorpcx.UncaughtTimeoutError):
                raise asyncio.TimeoutError from None
            except aiorpcx.TimeoutCancellationError:
                raise asyncio.CancelledError from None

    def async_timeout(delay: Union[int, float, None]):
        if delay is None:
            return nullcontext()
        return TimeoutAfterAsynciolike(delay)

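# Usage sketch for async_timeout / wait_for2 (illustrative only):
#
#     async def probe_peer():
#         try:
#             async with async_timeout(10):
#                 await asyncio.sleep(3600)  # some long-running operation
#         except asyncio.TimeoutError:
#             pass  # timed out after 10 seconds
#
#     # equivalently, wrapping a single awaitable:
#     #     result = await wait_for2(some_coro(), timeout=10)
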
class NetworkJobOnDefaultServer(Logger, ABC):
    """An abstract base class for a job that runs on the main network
    interface. Every time the main interface changes, the job is
    restarted, and some of its internals are reset.
    """
    def __init__(self, network: 'Network'):
        Logger.__init__(self)
        self.network = network
        self.interface = None  # type: Interface
        self._restart_lock = asyncio.Lock()
        # Ensure fairness between NetworkJobs. e.g. if multiple wallets
        # are open, a large wallet's Synchronizer should not starve the small wallets:
        self._network_request_semaphore = asyncio.Semaphore(100)

        self._reset()
        # every time the main interface changes, restart:
        register_callback(self._restart, ['default_server_changed'])
        # also schedule a one-off restart now, as there might already be a main interface:
        asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)

    def _reset(self):
        """Initialise fields. Called every time the underlying
        server connection changes.
        """
        self.taskgroup = OldTaskGroup()
        self.reset_request_counters()

    async def _start(self, interface: 'Interface'):
        self.logger.debug(f"starting. interface.server={repr(str(interface.server))}")
        self.interface = interface

        taskgroup = self.taskgroup

        async def run_tasks_wrapper():
            self.logger.debug(f"starting taskgroup ({hex(id(taskgroup))}).")
            try:
                await self._run_tasks(taskgroup=taskgroup)
            except Exception as e:
                self.logger.error(f"taskgroup died ({hex(id(taskgroup))}). exc={e!r}")
                raise
            finally:
                self.logger.debug(f"taskgroup stopped ({hex(id(taskgroup))}).")
        await interface.taskgroup.spawn(run_tasks_wrapper)

    @abstractmethod
    async def _run_tasks(self, *, taskgroup: OldTaskGroup) -> None:
        """Start tasks in taskgroup. Called every time the underlying
        server connection changes.
        """
        # If self.taskgroup changed, don't start tasks. This can happen if we have
        # been restarted *just now*, i.e. after the _run_tasks coroutine object was created.
        if taskgroup != self.taskgroup:
            raise asyncio.CancelledError()

    async def stop(self, *, full_shutdown: bool = True):
        self.logger.debug(f"stopping. {full_shutdown=}")
        if full_shutdown:
            unregister_callback(self._restart)
        await self.taskgroup.cancel_remaining()

    @log_exceptions
    async def _restart(self, *args):
        interface = self.network.interface
        if interface is None:
            return  # we should get called again soon

        async with self._restart_lock:
            await self.stop(full_shutdown=False)
            self._reset()
            await self._start(interface)

    def reset_request_counters(self):
        self._requests_sent = 0
        self._requests_answered = 0

    def num_requests_sent_and_answered(self) -> Tuple[int, int]:
        return self._requests_sent, self._requests_answered

    @property
    def session(self):
        s = self.interface.session
        assert s is not None
        return s

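# Minimal subclass sketch (illustrative only; `MyJob` and `_main_loop` are
# hypothetical names):
#
#     class MyJob(NetworkJobOnDefaultServer):
#         async def _run_tasks(self, *, taskgroup: OldTaskGroup) -> None:
#             await super()._run_tasks(taskgroup=taskgroup)
#             async with taskgroup as group:
#                 await group.spawn(self._main_loop())
#
#         async def _main_loop(self):
#             ...  # talk to self.session, guarded by self._network_request_semaphore
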
async def detect_tor_socks_proxy() -> Optional[Tuple[str, int]]:
    # Probable ports for Tor to listen on
    candidates = [
        ("127.0.0.1", 9050),
        ("127.0.0.1", 9051),
        ("127.0.0.1", 9150),
    ]

    proxy_addr = None

    async def test_net_addr(net_addr):
        is_tor = await is_tor_socks_port(*net_addr)
        # set result, and cancel remaining probes
        if is_tor:
            nonlocal proxy_addr
            proxy_addr = net_addr
            await group.cancel_remaining()

    async with OldTaskGroup() as group:
        for net_addr in candidates:
            await group.spawn(test_net_addr(net_addr))
    return proxy_addr


@log_exceptions
async def is_tor_socks_port(host: str, port: int) -> bool:
    # mimic "tor-resolve 0.0.0.0".
    # see https://github.com/spesmilo/electrum/issues/7317#issuecomment-1369281075
    # > this is a socks5 handshake, followed by a socks RESOLVE request as defined in
    # > [tor's socks extension spec](https://github.com/torproject/torspec/blob/7116c9cdaba248aae07a3f1d0e15d9dd102f62c5/socks-extensions.txt#L63),
    # > resolving 0.0.0.0, which being an IP, tor resolves itself without needing to ask a relay.
    writer = None
    try:
        async with async_timeout(10):
            reader, writer = await asyncio.open_connection(host, port)
            writer.write(b'\x05\x01\x00\x05\xf0\x00\x03\x070.0.0.0\x00\x00')
            await writer.drain()
            data = await reader.read(1024)
            if data == b'\x05\x00\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00':
                return True
            return False
    except (OSError, asyncio.TimeoutError):
        return False
    finally:
        if writer:
            writer.close()

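# Byte-level breakdown of the probe above (informational, following SOCKS5
# RFC 1928 and tor's socks extension spec):
#
#     b'\x05\x01\x00'         # SOCKS5 greeting: 1 auth method offered, "no auth"
#     b'\x05\xf0\x00\x03'     # SOCKS5 request: cmd 0xF0 (tor RESOLVE), rsv, domain-name atyp
#     b'\x070.0.0.0\x00\x00'  # 7-byte name "0.0.0.0", port 0
#
# A Tor SOCKS port answers with a success reply carrying the "resolved" IPv4
# address 0.0.0.0; anything else (or a timeout) is treated as "not Tor".
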
AS_LIB_USER_I_WANT_TO_MANAGE_MY_OWN_ASYNCIO_LOOP = False  # used by unit tests

_asyncio_event_loop = None  # type: Optional[asyncio.AbstractEventLoop]


def get_asyncio_loop() -> asyncio.AbstractEventLoop:
    """Returns the global asyncio event loop we use."""
    if loop := _asyncio_event_loop:
        return loop
    if AS_LIB_USER_I_WANT_TO_MANAGE_MY_OWN_ASYNCIO_LOOP:
        if loop := get_running_loop():
            return loop
    raise Exception("event loop not created yet")


def create_and_start_event_loop() -> Tuple[asyncio.AbstractEventLoop,
                                           asyncio.Future,
                                           threading.Thread]:
    global _asyncio_event_loop
    if _asyncio_event_loop is not None:
        raise Exception("there is already a running event loop")

    # asyncio.get_event_loop() became deprecated in python3.10. (see https://github.com/python/cpython/issues/83710)
    # We set a custom event loop policy purely to be compatible with code that
    # relies on asyncio.get_event_loop().
    # - in python 3.8-3.9, asyncio.Event.__init__, asyncio.Lock.__init__,
    #   and similar, call get_event_loop. see https://github.com/python/cpython/pull/23420
    class MyEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
        def get_event_loop(self):
            # In case electrum is being used as a library, there might be other
            # event loops in use besides ours. To minimise interfering with those,
            # if there is a loop running in the current thread, return that:
            running_loop = get_running_loop()
            if running_loop is not None:
                return running_loop
            # Otherwise, return our global loop:
            return get_asyncio_loop()
    asyncio.set_event_loop_policy(MyEventLoopPolicy())

    loop = asyncio.new_event_loop()
    _asyncio_event_loop = loop

    def on_exception(loop, context):
        """Suppress spurious messages it appears we cannot control."""
        SUPPRESS_MESSAGE_REGEX = re.compile('SSL handshake|Fatal read error on|'
                                            'SSL error in data received')
        message = context.get('message')
        if message and SUPPRESS_MESSAGE_REGEX.match(message):
            return
        loop.default_exception_handler(context)

    def run_event_loop():
        try:
            loop.run_until_complete(stopping_fut)
        finally:
            # clean-up
            try:
                pending_tasks = asyncio.gather(*asyncio.all_tasks(loop), return_exceptions=True)
                pending_tasks.cancel()
                with suppress(asyncio.CancelledError):
                    loop.run_until_complete(pending_tasks)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if isinstance(loop, asyncio.BaseEventLoop):
                    loop.run_until_complete(loop.shutdown_default_executor())
            except Exception as e:
                _logger.debug(f"exception when cleaning up asyncio event loop: {e}")

            global _asyncio_event_loop
            _asyncio_event_loop = None
            loop.close()

    loop.set_exception_handler(on_exception)
    _set_custom_task_factory(loop)
    # loop.set_debug(True)
    stopping_fut = loop.create_future()
    loop_thread = threading.Thread(
        target=run_event_loop,
        name='EventLoop',
    )
    loop_thread.start()
    # Wait until the loop actually starts.
    # On a slow PC, or with a debugger attached, this can take a few dozen ms,
    # and if we returned without a running loop, weird things can happen...
    t0 = time.monotonic()
    while not loop.is_running():
        time.sleep(0.01)
        if time.monotonic() - t0 > 5:
            raise Exception("been waiting for 5 seconds but asyncio loop would not start!")
    return loop, stopping_fut, loop_thread


_running_asyncio_tasks = set()  # type: Set[asyncio.Future]


def _set_custom_task_factory(loop: asyncio.AbstractEventLoop):
    """Wrap task creation to track pending and running tasks.
    When tasks are created, asyncio only maintains a weak reference to them.
    Hence, the garbage collector might destroy the task mid-execution.
    To avoid this, we store a strong reference to the task until it completes.

    Without this, a lot of APIs are basically Heisenbug-generators... e.g.:
    - "asyncio.create_task"
    - "loop.create_task"
    - "asyncio.ensure_future"
    - "asyncio.run_coroutine_threadsafe"

    related:
        - https://bugs.python.org/issue44665
        - https://github.com/python/cpython/issues/88831
        - https://github.com/python/cpython/issues/91887
        - https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
        - https://github.com/python/cpython/issues/91887#issuecomment-1434816045
        - "Task was destroyed but it is pending!"
    """

    platform_task_factory = loop.get_task_factory()

    def factory(loop_, coro, **kwargs):
        if platform_task_factory is not None:
            task = platform_task_factory(loop_, coro, **kwargs)
        else:
            task = asyncio.Task(coro, loop=loop_, **kwargs)
        _running_asyncio_tasks.add(task)
        task.add_done_callback(_running_asyncio_tasks.discard)
        return task

    loop.set_task_factory(factory)

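# The failure mode this guards against (illustrative sketch; `fire_and_forget`
# is a hypothetical helper):
#
#     def fire_and_forget(coro):
#         asyncio.create_task(coro)  # asyncio only keeps a weak reference!
#         # without a strong reference, the GC may collect the task before it
#         # finishes, logging "Task was destroyed but it is pending!"
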
def run_sync_function_on_asyncio_thread(func: Callable[[], Any], *, block: bool) -> None:
    """Run a non-async fn on the asyncio thread. Can be called from any thread.

    If the current thread is already the asyncio thread, func is guaranteed
    to have been completed when this method returns.

    For any other thread, we only wait for completion if `block` is True.
    """
    assert not inspect.iscoroutinefunction(func), "func must be a non-async function"
    asyncio_loop = get_asyncio_loop()
    if get_running_loop() == asyncio_loop:  # we are running on the asyncio thread
        func()
    else:  # non-asyncio thread
        async def wrapper():
            return func()
        fut = asyncio.run_coroutine_threadsafe(wrapper(), loop=asyncio_loop)
        if block:
            fut.result()
        else:
            # add explicit logging of exceptions, otherwise they might get lost
            tb1 = traceback.format_stack()[:-1]
            tb1_str = "".join(tb1)

            def on_done(fut_: concurrent.futures.Future):
                assert fut_.done()
                if fut_.cancelled():
                    _logger.debug(f"func cancelled. {func=}.")
                elif exc := fut_.exception():
                    # note: We explicitly log the first part of the traceback, tb1_str.
                    #       The second part gets logged by setting "exc_info".
                    _logger.error(
                        f"func errored. {func=}. {exc=}"
                        f"\n{tb1_str}", exc_info=exc)
            fut.add_done_callback(on_done)

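# Usage sketch (illustrative; `do_gui_unsafe_thing` stands in for any callable
# that must run on the event-loop thread):
#
#     run_sync_function_on_asyncio_thread(do_gui_unsafe_thing, block=True)
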
class OrderedDictWithIndex(OrderedDict):
    """An OrderedDict that keeps track of the positions of keys.

    Note: very inefficient to modify contents, except to add new items.
    """

    def __init__(self):
        super().__init__()
        self._key_to_pos = {}
        self._pos_to_key = {}

    def _recalc_index(self):
        self._key_to_pos = {key: pos for (pos, key) in enumerate(self.keys())}
        self._pos_to_key = {pos: key for (pos, key) in enumerate(self.keys())}

    def pos_from_key(self, key):
        return self._key_to_pos[key]

    def value_from_pos(self, pos):
        key = self._pos_to_key[pos]
        return self[key]

    def popitem(self, *args, **kwargs):
        ret = super().popitem(*args, **kwargs)
        self._recalc_index()
        return ret

    def move_to_end(self, *args, **kwargs):
        ret = super().move_to_end(*args, **kwargs)
        self._recalc_index()
        return ret

    def clear(self):
        ret = super().clear()
        self._recalc_index()
        return ret

    def pop(self, *args, **kwargs):
        ret = super().pop(*args, **kwargs)
        self._recalc_index()
        return ret

    def update(self, *args, **kwargs):
        ret = super().update(*args, **kwargs)
        self._recalc_index()
        return ret

    def __delitem__(self, *args, **kwargs):
        ret = super().__delitem__(*args, **kwargs)
        self._recalc_index()
        return ret

    def __setitem__(self, key, *args, **kwargs):
        is_new_key = key not in self
        ret = super().__setitem__(key, *args, **kwargs)
        if is_new_key:
            pos = len(self) - 1
            self._key_to_pos[key] = pos
            self._pos_to_key[pos] = key
        return ret

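# Usage sketch (illustrative):
#
#     d = OrderedDictWithIndex()
#     d['a'] = 10
#     d['b'] = 20
#     assert d.pos_from_key('b') == 1
#     assert d.value_from_pos(0) == 10
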
def make_object_immutable(obj):
    """Makes the passed object immutable recursively."""
    allowed_types = (
        dict, MappingProxyType, list, tuple, set, frozenset, str, int, float, bool, bytes, type(None)
    )
    assert isinstance(obj, allowed_types), f"{type(obj)=} cannot be made immutable"
    if isinstance(obj, (dict, MappingProxyType)):
        return MappingProxyType({k: make_object_immutable(v) for k, v in obj.items()})
    elif isinstance(obj, (list, tuple)):
        return tuple(make_object_immutable(item) for item in obj)
    elif isinstance(obj, (set, frozenset)):
        return frozenset(make_object_immutable(item) for item in obj)
    return obj

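# Usage sketch (illustrative): dicts become read-only mapping proxies, and
# containers are converted recursively:
#
#     frozen = make_object_immutable({'a': [1, 2]})
#     assert frozen['a'] == (1, 2)
#     frozen['a'] = 3  # raises TypeError: mappingproxy does not support item assignment
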
def multisig_type(wallet_type):
    """If wallet_type is mofn multi-sig, return [m, n],
    otherwise return None."""
    if not wallet_type:
        return None
    match = re.match(r'(\d+)of(\d+)', wallet_type)
    if match:
        match = [int(x) for x in match.group(1, 2)]
    return match

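# Usage sketch (illustrative):
#
#     assert multisig_type('2of3') == [2, 3]
#     assert multisig_type('standard') is None
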
def is_ip_address(x: Union[str, bytes]) -> bool:
    if isinstance(x, bytes):
        x = x.decode("utf-8")
    try:
        ipaddress.ip_address(x)
        return True
    except ValueError:
        return False


def is_localhost(host: str) -> bool:
    if str(host) in ('localhost', 'localhost.',):
        return True
    if host[0] == '[' and host[-1] == ']':  # IPv6
        host = host[1:-1]
    try:
        ip_addr = ipaddress.ip_address(host)  # type: Union[IPv4Address, IPv6Address]
        return ip_addr.is_loopback
    except ValueError:
        pass  # not an IP
    return False


def is_private_netaddress(host: str) -> bool:
    if is_localhost(host):
        return True
    if host[0] == '[' and host[-1] == ']':  # IPv6
        host = host[1:-1]
    try:
        ip_addr = ipaddress.ip_address(host)  # type: Union[IPv4Address, IPv6Address]
        return ip_addr.is_private
    except ValueError:
        pass  # not an IP
    return False


def list_enabled_bits(x: int) -> Sequence[int]:
    """e.g. 77 (0b1001101) --> (0, 2, 3, 6)"""
    binary = bin(x)[2:]
    rev_bin = reversed(binary)
    return tuple(i for i, b in enumerate(rev_bin) if b == '1')


async def resolve_dns_srv(host: str):
    # FIXME this method is not using the network proxy. (although the proxy might not support UDP?)
    srv_records = await dns.asyncresolver.resolve(host, 'SRV')
    # priority: prefer lower
    # weight: tie breaker; prefer higher
    srv_records = sorted(srv_records, key=lambda x: (x.priority, -x.weight))

    def dict_from_srv_record(srv):
        return {
            'host': str(srv.target),
            'port': srv.port,
        }
    return [dict_from_srv_record(srv) for srv in srv_records]


def randrange(bound: int) -> int:
    """Return a random integer k such that 1 <= k < bound, uniformly
    distributed across that range.
    This is guaranteed to be cryptographically strong.
    """
    # secrets.randbelow(bound) returns a random int: 0 <= r < bound,
    # hence the transformation:
    return secrets.randbelow(bound - 1) + 1


class CallbackManager(Logger):
    # callbacks set by the GUI or any thread
    # guarantee: the callbacks will always get triggered from the asyncio thread.

    # FIXME: There should be a way to prevent circular callbacks.
    # At the very least, we need a distinction between callbacks that
    # are for the GUI and callbacks between wallet components.

    def __init__(self):
        Logger.__init__(self)
        self.callback_lock = threading.Lock()
        self.callbacks = defaultdict(list)  # type: Dict[str, List[Callable]]  # note: needs self.callback_lock

    def register_callback(self, func: Callable, events: Sequence[str]) -> None:
        with self.callback_lock:
            for event in events:
                self.callbacks[event].append(func)

    def unregister_callback(self, callback: Callable) -> None:
        with self.callback_lock:
            for callbacks in self.callbacks.values():
                if callback in callbacks:
                    callbacks.remove(callback)

    def clear_all_callbacks(self) -> None:
        with self.callback_lock:
            self.callbacks.clear()

    def trigger_callback(self, event: str, *args) -> None:
        """Trigger a callback with given arguments.
        Can be called from any thread. The callback itself will get scheduled
        on the event loop.
        """
        loop = get_asyncio_loop()
        assert loop.is_running(), "event loop not running"
        with self.callback_lock:
            callbacks = self.callbacks[event][:]
        for callback in callbacks:
            if inspect.iscoroutinefunction(callback):  # async cb
                fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)

                def on_done(fut_: concurrent.futures.Future):
                    assert fut_.done()
                    if fut_.cancelled():
                        self.logger.debug(f"cb cancelled. {event=}.")
                    elif exc := fut_.exception():
                        self.logger.error(f"cb errored. {event=}. {exc=}", exc_info=exc)
                fut.add_done_callback(on_done)
            else:  # non-async cb
                run_sync_function_on_asyncio_thread(partial(callback, *args), block=False)


callback_mgr = CallbackManager()
trigger_callback = callback_mgr.trigger_callback
register_callback = callback_mgr.register_callback
unregister_callback = callback_mgr.unregister_callback
_event_listeners = defaultdict(set)  # type: Dict[str, Set[str]]


class EventListener:
    """Use as a mixin for a class that has methods to be triggered on events.
    - Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
    - register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__
    - unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
    """

    def _list_callbacks(self):
        for c in self.__class__.__mro__:
            classpath = f"{c.__module__}.{c.__name__}"
            for method_name in _event_listeners[classpath]:
                method = getattr(self, method_name)
                assert callable(method)
                assert method_name.startswith('on_event_')
                yield method_name[len('on_event_'):], method

    def register_callbacks(self):
        for name, method in self._list_callbacks():
            #_logger.debug(f'registering callback {method}')
            register_callback(method, [name])

    def unregister_callbacks(self):
        for name, method in self._list_callbacks():
            #_logger.debug(f'unregistering callback {method}')
            unregister_callback(method)


def event_listener(func):
    """To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""
    classname, method_name = func.__qualname__.split('.')
    assert method_name.startswith('on_event_')
    classpath = f"{func.__module__}.{classname}"
    _event_listeners[classpath].add(method_name)
    return func

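# Usage sketch for EventListener/@event_listener (illustrative; the class and
# the 'status' event name are hypothetical):
#
#     class StatusDisplay(EventListener):
#         def __init__(self):
#             self.register_callbacks()
#
#         @event_listener
#         def on_event_status(self, *args):
#             ...  # runs on the asyncio thread whenever trigger_callback('status', ...) fires
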
_NetAddrType = TypeVar("_NetAddrType")
# requirements for _NetAddrType:
# - reasonable __hash__() implementation (e.g. based on host/port of remote endpoint)


class NetworkRetryManager(Generic[_NetAddrType]):
    """Truncated Exponential Backoff for network connections."""

    def __init__(
            self, *,
            max_retry_delay_normal: float,
            init_retry_delay_normal: float,
            max_retry_delay_urgent: Optional[float] = None,
            init_retry_delay_urgent: Optional[float] = None,
    ):
        self._last_tried_addr = {}  # type: Dict[_NetAddrType, Tuple[float, int]]  # (unix ts, num_attempts)

        # note: these all use "seconds" as unit
        if max_retry_delay_urgent is None:
            max_retry_delay_urgent = max_retry_delay_normal
        if init_retry_delay_urgent is None:
            init_retry_delay_urgent = init_retry_delay_normal
        self._max_retry_delay_normal = max_retry_delay_normal
        self._init_retry_delay_normal = init_retry_delay_normal
        self._max_retry_delay_urgent = max_retry_delay_urgent
        self._init_retry_delay_urgent = init_retry_delay_urgent

    def _trying_addr_now(self, addr: _NetAddrType) -> None:
        last_time, num_attempts = self._last_tried_addr.get(addr, (0, 0))
        # we add up to 1 second of noise to the time, so that clients are less likely
        # to get synchronised and bombard the remote in connection waves:
        cur_time = time.time() + random.random()
        self._last_tried_addr[addr] = cur_time, num_attempts + 1

    def _on_connection_successfully_established(self, addr: _NetAddrType) -> None:
        self._last_tried_addr[addr] = time.time(), 0

    def _can_retry_addr(self, addr: _NetAddrType, *,
                        now: Optional[float] = None, urgent: bool = False) -> bool:
        if now is None:
            now = time.time()
        last_time, num_attempts = self._last_tried_addr.get(addr, (0, 0))
        if urgent:
            max_delay = self._max_retry_delay_urgent
            init_delay = self._init_retry_delay_urgent
        else:
            max_delay = self._max_retry_delay_normal
            init_delay = self._init_retry_delay_normal
        delay = self.__calc_delay(multiplier=init_delay, max_delay=max_delay, num_attempts=num_attempts)
        next_time = last_time + delay
        return next_time < now

    @classmethod
    def __calc_delay(cls, *, multiplier: float, max_delay: float,
                     num_attempts: int) -> float:
        num_attempts = min(num_attempts, 100_000)
        try:
            res = multiplier * 2 ** num_attempts
        except OverflowError:
            return max_delay
        return max(0, min(max_delay, res))

    def _clear_addr_retry_times(self) -> None:
        self._last_tried_addr.clear()

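# Worked backoff example (illustrative): with init_retry_delay_normal=1.0 and
# max_retry_delay_normal=60.0, after n failed attempts the next retry is
# allowed 1.0 * 2**n seconds after the last try, i.e. 2, 4, 8, 16, 32, then
# capped at 60 seconds; a successful connection resets the attempt counter.
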
class ESocksProxy(aiorpcx.SOCKSProxy):
    # note: the proxy will not leak DNS, as create_connection()
    # sets (local DNS) resolve=False by default

    async def open_connection(self, host=None, port=None, **kwargs):
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader(loop=loop)
        protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
        transport, _ = await self.create_connection(
            lambda: protocol, host, port, **kwargs)
        writer = asyncio.StreamWriter(transport, protocol, reader, loop)
        return reader, writer

    @classmethod
    def from_network_settings(cls, network: Optional['Network']) -> Optional['ESocksProxy']:
        if not network or not network.proxy or not network.proxy.enabled:
            return None
        proxy = network.proxy
        username, pw = proxy.user, proxy.password
        if not username or not pw:
            # is_proxy_tor is tri-state; None indicates it is still probing the proxy to test for Tor
            if network.is_proxy_tor:
                auth = aiorpcx.socks.SOCKSRandomAuth()
            else:
                auth = None
        else:
            auth = aiorpcx.socks.SOCKSUserAuth(username, pw)
        addr = aiorpcx.NetAddress(proxy.host, proxy.port)
        if proxy.mode == "socks4":
            ret = cls(addr, aiorpcx.socks.SOCKS4a, auth)
        elif proxy.mode == "socks5":
            ret = cls(addr, aiorpcx.socks.SOCKS5, auth)
        else:
            raise NotImplementedError  # http proxy not available with aiorpcx
        return ret


class JsonRPCError(Exception):

    class Codes(enum.IntEnum):
        # application-specific error codes
        USERFACING = 1
        INTERNAL = 2

    def __init__(self, *, code: int, message: str, data: Optional[dict] = None):
        Exception.__init__(self)
        self.code = code
        self.message = message
        self.data = data


class JsonRPCClient:

    def __init__(self, session: aiohttp.ClientSession, url: str):
        self.session = session
        self.url = url
        self._id = 0

    async def request(self, endpoint, *args):
        """Send request to server, parse and return result.
        note: parsing code is naive, the server is assumed to be well-behaved.
              Up to the caller to handle exceptions, including those arising from parsing errors.
        """
        self._id += 1
        data = ('{"jsonrpc": "2.0", "id":"%d", "method": "%s", "params": %s }'
                % (self._id, endpoint, json.dumps(args)))
        async with self.session.post(self.url, data=data) as resp:
            if resp.status == 200:
                r = await resp.json()
                result = r.get('result')
                error = r.get('error')
                if error:
                    raise JsonRPCError(code=error["code"], message=error["message"], data=error.get("data"))
                else:
                    return result
            else:
                text = await resp.text()
                return 'Error: ' + str(text)

    def add_method(self, endpoint):
        async def coro(*args):
            return await self.request(endpoint, *args)
        setattr(self, endpoint, coro)

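# Usage sketch for JsonRPCClient (illustrative; the URL and method name are
# hypothetical and must match the actual server):
#
#     async def get_height():
#         async with make_aiohttp_session(proxy=None) as session:
#             client = JsonRPCClient(session, "http://127.0.0.1:8332")
#             client.add_method("getblockcount")
#             return await client.getblockcount()
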
T = TypeVar('T')


def random_shuffled_copy(x: Iterable[T]) -> List[T]:
    """Returns a shuffled copy of the input."""
    x_copy = list(x)  # copy
    random.shuffle(x_copy)  # shuffle in-place
    return x_copy


def test_read_write_permissions(path) -> None:
    # note: There might already be a file at 'path'.
    #       Make sure we do NOT overwrite/corrupt that!
    temp_path = "%s.tmptest.%s" % (path, os.getpid())
    echo = "fs r/w test"
    try:
        # test READ permissions for actual path
        if os.path.exists(path):
            with open(path, "rb") as f:
                f.read(1)  # read 1 byte
        # test R/W sanity for "similar" path
        with open(temp_path, "w", encoding='utf-8') as f:
            f.write(echo)
        with open(temp_path, "r", encoding='utf-8') as f:
            echo2 = f.read()
        os.remove(temp_path)
    except Exception as e:
        raise IOError(e) from e
    if echo != echo2:
        raise IOError('echo sanity-check failed')


class classproperty(property):
    """~read-only class-level @property
    from https://stackoverflow.com/a/13624858 by denis-ryzhkov
    """
    def __get__(self, owner_self, owner_cls):
        return self.fget(owner_cls)


def sticky_property(val):
    """Creates a 'property' whose value cannot be changed and that cannot be deleted.
    Attempts to change the value are silently ignored.

    >>> class C: pass
    ...
    >>> setattr(C, 'x', sticky_property(3))
    >>> c = C()
    >>> c.x
    3
    >>> c.x = 2
    >>> c.x
    3
    >>> del c.x
    >>> c.x
    3
    """
    return property(
        fget=lambda self: val,
        fset=lambda *args, **kwargs: None,
        fdel=lambda *args, **kwargs: None,
    )


def get_running_loop() -> Optional[asyncio.AbstractEventLoop]:
    """Returns the asyncio event loop that is *running in this thread*, if any."""
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        return None


def error_text_str_to_safe_str(err: str, *, max_len: Optional[int] = 500) -> str:
    """Converts an untrusted error string to a sane printable ascii str.
    Never raises.
    """
    text = error_text_bytes_to_safe_str(
        err.encode("ascii", errors='backslashreplace'),
        max_len=None)
    return truncate_text(text, max_len=max_len)


def error_text_bytes_to_safe_str(err: bytes, *, max_len: Optional[int] = 500) -> str:
    """Converts untrusted error bytes to a sane printable ascii str.
    Never raises.

    Note that naive ascii conversion would be insufficient. Fun stuff:
    >>> b = b"my_long_prefix_blabla" + 21 * b"\x08" + b"malicious_stuff"
    >>> s = b.decode("ascii")
    >>> print(s)
    malicious_stuffblabla
    """
    # convert to ascii, to get rid of unicode stuff
    ascii_text = err.decode("ascii", errors='backslashreplace')
    # do repr to handle ascii special chars (especially when printing/logging the str)
    text = repr(ascii_text)
    return truncate_text(text, max_len=max_len)


def truncate_text(text: str, *, max_len: Optional[int]) -> str:
    if max_len is None or len(text) <= max_len:
        return text
    else:
        return text[:max_len] + f"... (truncated. orig_len={len(text)})"

def nostr_pow_worker(nonce, nostr_pubk, target_bits, hash_function, hash_len_bits, shutdown):
    """Function to generate PoW for Nostr, to be spawned in a ProcessPoolExecutor."""
    hash_preimage = b'electrum-' + nostr_pubk
    while True:
        # checking shutdown.is_set() on every iteration would add a lot of overhead,
        # so we only check it once per chunk of nonces; the inner range() counter
        # keeps the per-iteration cost low
        for i in range(1000000):
            digest = hash_function(hash_preimage + nonce.to_bytes(32, 'big')).digest()
            if int.from_bytes(digest, 'big') < (1 << (hash_len_bits - target_bits)):
                shutdown.set()
                return digest, nonce
            nonce += 1
        if shutdown.is_set():
            return None, None


async def gen_nostr_ann_pow(nostr_pubk: bytes, target_bits: int) -> Tuple[int, int]:
    """Generate a PoW for a Nostr announcement. The PoW is hash[b'electrum-'+pubk+nonce]"""
    import multiprocessing  # not available on Android, so we import it here
    hash_function = hashlib.sha256
    hash_len_bits = 256
    max_nonce: int = (1 << (32 * 8)) - 1  # 32-byte nonce
    start_nonce = 0

    max_workers = max(multiprocessing.cpu_count() - 1, 1)  # use all but one CPU
    manager = multiprocessing.Manager()
    shutdown = manager.Event()
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        tasks = []
        loop = asyncio.get_running_loop()
        for _ in range(max_workers):
            task = loop.run_in_executor(
                executor,
                nostr_pow_worker,
                start_nonce,
                nostr_pubk,
                target_bits,
                hash_function,
                hash_len_bits,
                shutdown
            )
            tasks.append(task)
            start_nonce += max_nonce // max_workers  # split the nonce range between the processes
            if start_nonce > max_nonce:  # make sure we don't go over the max_nonce
                start_nonce = random.randint(0, int(max_nonce * 0.75))

        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        hash_res, nonce_res = done.pop().result()
        executor.shutdown(wait=False, cancel_futures=True)

    return nonce_res, get_nostr_ann_pow_amount(nostr_pubk, nonce_res)


def get_nostr_ann_pow_amount(nostr_pubk: bytes, nonce: Optional[int]) -> int:
    """Return the amount of leading zero bits for a nostr announcement PoW."""
    if not nonce:
        return 0
    hash_function = hashlib.sha256
    hash_len_bits = 256
    hash_preimage = b'electrum-' + nostr_pubk

    digest = hash_function(hash_preimage + nonce.to_bytes(32, 'big')).digest()
    digest = int.from_bytes(digest, 'big')
    return hash_len_bits - digest.bit_length()

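# Worked example (illustrative): the PoW amount is the number of leading zero
# bits of sha256(b'electrum-' + pubk + nonce_32_bytes). A 256-bit digest whose
# integer value v satisfies v < 2**(256 - target_bits) has at least target_bits
# leading zero bits, since then v.bit_length() <= 256 - target_bits.
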
class OnchainHistoryItem(NamedTuple):
    txid: str
    amount_sat: int
    fee_sat: int
    balance_sat: int
    tx_mined_status: TxMinedInfo
    group_id: Optional[str]
    label: Optional[str]
    monotonic_timestamp: int

    def to_dict(self):
        return {
            'txid': self.txid,
            'amount_sat': self.amount_sat,
            'fee_sat': self.fee_sat,
            'height': self.tx_mined_status.height,
            'confirmations': self.tx_mined_status.conf,
            'timestamp': self.tx_mined_status.timestamp,
            'monotonic_timestamp': self.monotonic_timestamp,
            'incoming': self.amount_sat > 0,
            'bc_value': Satoshis(self.amount_sat),
            'bc_balance': Satoshis(self.balance_sat),
            'date': timestamp_to_datetime(self.tx_mined_status.timestamp),
            'txpos_in_block': self.tx_mined_status.txpos,
            'wanted_height': self.tx_mined_status.wanted_height,
            'label': self.label,
            'group_id': self.group_id,
        }


class LightningHistoryItem(NamedTuple):
    payment_hash: Optional[str]
    preimage: Optional[str]
    amount_msat: int
    fee_msat: Optional[int]
    type: str
    group_id: Optional[str]
    timestamp: int
    label: Optional[str]
    direction: Optional[int]

    def to_dict(self):
        return {
            'type': self.type,
            'label': self.label,
            'timestamp': self.timestamp or 0,
            'date': timestamp_to_datetime(self.timestamp),
            'amount_msat': self.amount_msat,
            'fee_msat': self.fee_msat,
            'payment_hash': self.payment_hash,
            'preimage': self.preimage,
            'group_id': self.group_id,
            'ln_value': Satoshis(Decimal(self.amount_msat) / 1000),
            'direction': self.direction,
        }


@dataclass(kw_only=True, slots=True)
class ChoiceItem:
    key: Any
    label: str  # user facing string
    extra_data: Any = None