/electrum/interface.py
#!/usr/bin/env python
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2011 thomasv@gitorious
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import re
import ssl
import sys
import time
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
import logging
import hashlib
import functools
import random
import enum

import aiorpcx
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
from aiorpcx.curio import timeout_after, TaskTimeout
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient, RSTransport
import certifi

from .util import (ignore_exceptions, log_exceptions, bfh, ESocksProxy,
                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
                   is_int_or_float, is_non_negative_int_or_float, OldTaskGroup)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE, CHUNK_SIZE
from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
from .transaction import Transaction
from .fee_policy import FEE_ETA_TARGETS

if TYPE_CHECKING:
    from .network import Network
    from .simple_config import SimpleConfig


ca_path = certifi.where()

BUCKET_NAME_OF_ONION_SERVERS = 'onion'

_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS

MAX_NUM_HEADERS_PER_REQUEST = 2016
assert MAX_NUM_HEADERS_PER_REQUEST >= CHUNK_SIZE
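# note: 2016 headers corresponds to one Bitcoin difficulty-retarget period;
# blockchain.CHUNK_SIZE is the same value, so the assert above guarantees that
# a whole chunk can be fetched in a single request.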


class NetworkTimeout:
    # seconds
    class Generic:
        NORMAL = 30
        RELAXED = 45
        MOST_RELAXED = 600

    class Urgent(Generic):
        NORMAL = 10
        RELAXED = 20
        MOST_RELAXED = 60


def assert_non_negative_integer(val: Any) -> None:
    if not is_non_negative_integer(val):
        raise RequestCorrupted(f'{val!r} should be a non-negative integer')


def assert_integer(val: Any) -> None:
    if not is_integer(val):
        raise RequestCorrupted(f'{val!r} should be an integer')


def assert_int_or_float(val: Any) -> None:
    if not is_int_or_float(val):
        raise RequestCorrupted(f'{val!r} should be int or float')


def assert_non_negative_int_or_float(val: Any) -> None:
    if not is_non_negative_int_or_float(val):
        raise RequestCorrupted(f'{val!r} should be a non-negative int or float')


def assert_hash256_str(val: Any) -> None:
    if not is_hash256_str(val):
        raise RequestCorrupted(f'{val!r} should be a hash256 str')


def assert_hex_str(val: Any) -> None:
    if not is_hex_str(val):
        raise RequestCorrupted(f'{val!r} should be a hex str')


def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    if not isinstance(d, dict):
        raise RequestCorrupted(f'{d!r} should be a dict')
    if field_name not in d:
        raise RequestCorrupted(f'required field {field_name!r} missing from dict')
    return d[field_name]


def assert_list_or_tuple(val: Any) -> None:
    if not isinstance(val, (list, tuple)):
        raise RequestCorrupted(f'{val!r} should be a list or tuple')


class ChainResolutionMode(enum.Enum):
    CATCHUP = enum.auto()
    BACKWARD = enum.auto()
    BINARY = enum.auto()
    FORK = enum.auto()
    NO_FORK = enum.auto()


class NotificationSession(RPCSession):

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super(NotificationSession, self).__init__(*args, **kwargs)
        self.subscriptions = defaultdict(list)
        self.cache = {}
        self._msg_counter = itertools.count(start=1)
        self.interface = interface
        self.taskgroup = interface.taskgroup
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        self.maybe_log(f"--> {request}")
        try:
            if isinstance(request, Notification):
                params, result = request.args[:-1], request.args[-1]
                key = self.get_hashable_key_for_rpc_call(request.method, params)
                if key in self.subscriptions:
                    self.cache[key] = result
                    for queue in self.subscriptions[key]:
                        await queue.put(request.args)
                else:
                    raise Exception('unexpected notification')
            else:
                raise Exception('unexpected request. not a notification')
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        # note: semaphores/timeouts/backpressure etc are handled by
        # aiorpcx. the timeout arg here in most cases should not be set
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # note: RPCSession.send_request raises TaskTimeout in case of a timeout.
            # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
            response = await util.wait_for2(
                super().send_request(*args, **kwargs),
                timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            self.maybe_log(f"--> request timed out: {args} (id: {msg_id})")
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        except BaseException as e:  # cancellations, etc. are useful for debugging
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        else:
            self.maybe_log(f"--> {response} (id: {msg_id})")
            return response

    def set_default_timeout(self, timeout):
        assert hasattr(self, "sent_request_timeout")  # in base class
        self.sent_request_timeout = timeout
        assert hasattr(self, "max_send_delay")        # in base class
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key in self.cache:
            result = self.cache[key]
        else:
            result = await self.send_request(method, params)
            self.cache[key] = result
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for v in self.subscriptions.values():
            if queue in v:
                v.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return str(method) + repr(params)

    def maybe_log(self, msg: str) -> None:
        if not self.interface: return
        if self.interface.debug or self.interface.network.debug:
            self.interface.logger.debug(msg)

    def default_framer(self):
        # overridden so that max_size can be customized
        max_size = self.interface.network.config.NETWORK_MAX_INCOMING_MSG_SIZE
        assert max_size > 500_000, f"{max_size=} (< 500_000) is too small"
        return NewlineFramer(max_size=max_size)

    async def close(self, *, force_after: int = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if force_after is None:
            # We give up after a while and just abort the connection.
            # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
            #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
            # Note: if the ethernet cable was pulled or wifi disconnected, that too might
            #       wait until this timeout is triggered
            force_after = 1  # seconds
        await super().close(force_after=force_after)


class NetworkException(Exception): pass


class GracefulDisconnect(NetworkException):
    log_level = logging.INFO

    def __init__(self, *args, log_level=None, **kwargs):
        Exception.__init__(self, *args, **kwargs)
        if log_level is not None:
            self.log_level = log_level


class RequestTimedOut(GracefulDisconnect):
    def __str__(self):
        return _("Network request timed out.")


class RequestCorrupted(Exception): pass

class ErrorParsingSSLCert(Exception): pass
class ErrorGettingSSLCertFromServer(Exception): pass
class ErrorSSLCertFingerprintMismatch(Exception): pass
class InvalidOptionCombination(Exception): pass
class ConnectError(NetworkException): pass


class _RSClient(RSClient):
    async def create_connection(self):
        try:
            return await super().create_connection()
        except OSError as e:
            # note: using "from e" here will set __cause__ of ConnectError
            raise ConnectError(e) from e


class PaddedRSTransport(RSTransport):
    """A raw socket transport that provides basic countermeasures against traffic analysis
    by padding the jsonrpc payload with whitespaces to have ~uniform-size TCP packets.
    (it is assumed that a network observer does not see plaintext transport contents,
    due to it being wrapped e.g. in TLS)
    """

    MIN_PACKET_SIZE = 1024
    WAIT_FOR_BUFFER_GROWTH_SECONDS = 1.0

    session: Optional['RPCSession']

    def __init__(self, *args, **kwargs):
        RSTransport.__init__(self, *args, **kwargs)
        self._sbuffer = bytearray()  # "send buffer"
        self._sbuffer_task = None  # type: Optional[asyncio.Task]
        self._sbuffer_has_data_evt = asyncio.Event()
        self._last_send = time.monotonic()
        self._force_send = False  # type: bool

    # note: this does not call super().write() but is a complete reimplementation
    async def write(self, message):
        await self._can_send.wait()
        if self.is_closing():
            return
        framed_message = self._framer.frame(message)
        self._sbuffer += framed_message
        self._sbuffer_has_data_evt.set()
        self._maybe_consume_sbuffer()

    def _maybe_consume_sbuffer(self) -> None:
        """Maybe take some data from sbuffer and send it on the wire."""
        if not self._can_send.is_set() or self.is_closing():
            return
        buf = self._sbuffer
        if not buf:
            return
        # if there is enough data in the buffer, or if we haven't sent in a while, send now:
        if not (
            self._force_send
            or len(buf) >= self.MIN_PACKET_SIZE
            or self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS < time.monotonic()
        ):
            return
        assert buf[-2:] in (b"}\n", b"]\n"), f"unexpected json-rpc terminator: {buf[-2:]=!r}"
        # either (1) pad length to next power of two, to create "lsize" packet:
        payload_lsize = len(buf)
        total_lsize = max(self.MIN_PACKET_SIZE, 2 ** (payload_lsize.bit_length()))
        npad_lsize = total_lsize - payload_lsize
        # or if that wasted a lot of bandwidth with padding, (2) defer sending some messages
        # and create a packet with half that size ("ssize", s for small)
        total_ssize = max(self.MIN_PACKET_SIZE, total_lsize // 2)
        payload_ssize = buf.rfind(b"\n", 0, total_ssize)
        if payload_ssize != -1:
            payload_ssize += 1  # for "\n" char
            npad_ssize = total_ssize - payload_ssize
        else:
            npad_ssize = float("inf")
        # decide between (1) and (2):
        if self._force_send or npad_lsize <= npad_ssize:
            # (1) create "lsize" packet: consume full buffer
            npad = npad_lsize
            p_idx = payload_lsize
        else:
            # (2) create "ssize" packet: consume some, but defer some for later
            npad = npad_ssize
            p_idx = payload_ssize
        # pad by adding spaces near end
        # self.session.maybe_log(
        #     f"PaddedRSTransport. calling low-level write(). "
        #     f"chose between (lsize:{payload_lsize}+{npad_lsize}, ssize:{payload_ssize}+{npad_ssize}). "
        #     f"won: {'tie' if npad_lsize == npad_ssize else 'lsize' if npad_lsize < npad_ssize else 'ssize'}."
        # )
        json_rpc_terminator = buf[p_idx-2:p_idx]
        assert json_rpc_terminator in (b"}\n", b"]\n"), f"unexpected {json_rpc_terminator=!r}"
        buf2 = buf[:p_idx-2] + (npad * b" ") + json_rpc_terminator
        self._asyncio_transport.write(buf2)
        self._last_send = time.monotonic()
        del self._sbuffer[:p_idx]
        if not self._sbuffer:
            self._sbuffer_has_data_evt.clear()

    async def _poll_sbuffer(self):
        while not self.is_closing():
            await self._can_send.wait()
            await self._sbuffer_has_data_evt.wait()  # to avoid busy-waiting
            self._maybe_consume_sbuffer()
            # If there is still data in the buffer, sleep until it would time out.
            # note: If the transport is ~idle, when we wake up, we will send the current buf data,
            #       but if busy, we might wake up to completely new buffer contents. Either is fine.
            if len(self._sbuffer) > 0:
                timeout_abs = self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS
                timeout_rel = max(0.0, timeout_abs - time.monotonic())
                await asyncio.sleep(timeout_rel)

    def connection_made(self, transport: asyncio.BaseTransport):
        super().connection_made(transport)
        if isinstance(self.session, NotificationSession):
            coro = self.session.taskgroup.spawn(self._poll_sbuffer())
            self._sbuffer_task = self.loop.create_task(coro)
        else:
            # This is a short-lived "fetch_certificate"-type session.
            # No polling here, we always force-empty the buffer.
            self._force_send = True


class ServerAddr:

    def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Constructs a ServerAddr or raises ValueError."""
        # host might be IPv6 address, hence do rsplit:
        host, port, protocol = str(s).rsplit(':', 2)
        return ServerAddr(host=host, port=port, protocol=protocol)
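
    # Illustration: from_str expects "host:port:protocol", e.g.
    # "example.org:50002:s". The rsplit keeps IPv6 hosts intact:
    # "[2001:db8::1]:50002:s" -> host "2001:db8::1", port 50002, protocol "s"
    # (the brackets are stripped by the constructor).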

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.
        Does not raise - just returns None if guessing failed.
        Ongoing compatibility not guaranteed.
        """
        if not s:
            return None
        host = ""
        if s[0] == "[" and "]" in s:  # IPv6 address
            host_end = s.index("]")
            host = s[1:host_end]
            s = s[host_end+1:]
        items = str(s).rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        host = host or items[0]
        port = items[1]
        if len(items) >= 3:
            protocol = items[2]
        else:
            protocol = PREFERRED_NETWORK_PROTOCOL
        try:
            return ServerAddr(host=host, port=port, protocol=protocol)
        except ValueError:
            return None

    def to_friendly_name(self) -> str:
        # note: this method is closely linked to from_str_with_inference
        if self.protocol == 's':  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            return False
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash((self.host, self.port, self.protocol))


def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
    filename = host
    try:
        ip = ip_address(host)
    except ValueError:
        pass
    else:
        if isinstance(ip, IPv6Address):
            filename = f"ipv6_{ip.packed.hex()}"
    return os.path.join(config.path, 'certs', filename)


class Interface(Logger):

    def __init__(self, *, network: 'Network', server: ServerAddr):
        assert isinstance(server, ServerAddr), f"expected ServerAddr, got {type(server)}"
        self.ready = network.asyncio_loop.create_future()
        self.got_disconnected = asyncio.Event()
        self.server = server
        Logger.__init__(self)
        assert network.config.path
        self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
        self.blockchain = None  # type: Optional[Blockchain]
        self._requested_chunks = set()  # type: Set[int]
        self.network = network
        self.session = None  # type: Optional[NotificationSession]
        self._ipaddr_bucket = None
        # Set up proxy.
        # - for servers running on localhost, the proxy is not used. If user runs their own server
        #   on same machine, this lets them enable the proxy (which is used for e.g. FX rates).
        #   note: we could maybe relax this further and bypass the proxy for all private
        #         addresses...? e.g. 192.168.x.x
        if util.is_localhost(server.host):
            self.logger.info("looks like localhost: not using proxy for this server")
            self.proxy = None
        else:
            self.proxy = ESocksProxy.from_network_settings(network)

        # Latest block header and corresponding height, as claimed by the server.
        # Note that these values are updated before they are verified.
        # Especially during initial header sync, verification can take a long time.
        # Failing verification will get the interface closed.
        self.tip_header = None  # type: Optional[dict]
        self.tip = 0

        self._headers_cache = {}  # type: Dict[int, bytes]

        self.fee_estimates_eta = {}  # type: Dict[int, int]

        # Dump network messages (only for this interface).  Set at runtime from the console.
        self.debug = False

        self.taskgroup = OldTaskGroup()

        async def spawn_task():
            task = await self.network.taskgroup.spawn(self.run())
            task.set_name(f"interface::{str(server)}")
        asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)

    @property
    def host(self):
        return self.server.host

    @property
    def port(self):
        return self.server.port

    @property
    def protocol(self):
        return self.server.protocol

    def diagnostic_name(self):
        return self.server.net_addr_str()

    def __str__(self):
        return f"<Interface {self.diagnostic_name()}>"

    async def is_server_ca_signed(self, ca_ssl_context: ssl.SSLContext) -> bool:
        """Given a CA enforcing SSL context, returns True if the connection
        can be established. Returns False if the server has a self-signed
        certificate but otherwise is okay. Any other failures raise.
        """
        try:
            await self.open_session(ssl_context=ca_ssl_context, exit_early=True)
        except ConnectError as e:
            cause = e.__cause__
            if (isinstance(cause, ssl.SSLCertVerificationError)
                    and cause.reason == 'CERTIFICATE_VERIFY_FAILED'
                    and cause.verify_code == 18):  # "self signed certificate"
                # Good. We will use this server as self-signed.
                return False
            # Not good. Cannot use this server.
            raise
        # Good. We will use this server as CA-signed.
        return True

    async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context: ssl.SSLContext) -> None:
        ca_signed = await self.is_server_ca_signed(ca_ssl_context)
        if ca_signed:
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            with open(self.cert_path, 'w') as f:
                # empty file means this is CA signed, not self-signed
                f.write('')
        else:
            await self._save_certificate()

    def _is_saved_ssl_cert_available(self):
        if not os.path.exists(self.cert_path):
            return False
        with open(self.cert_path, 'r') as f:
            contents = f.read()
        if contents == '':  # CA signed
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            return True
        # pinned self-signed cert
        try:
            b = pem.dePem(contents, 'CERTIFICATE')
        except SyntaxError as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x = x509.X509(b)
        except Exception as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x.check_date()
        except x509.CertificateError as e:
            self.logger.info(f"certificate has expired: {e}")
            os.unlink(self.cert_path)  # delete pinned cert only in this case
            return False
        self._verify_certificate_fingerprint(bytes(b))
        return True

    async def _get_ssl_context(self) -> Optional[ssl.SSLContext]:
        if self.protocol != 's':
            # using plaintext TCP
            return None

        # see if we already have cert for this server; or get it for the first time
        ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
        if not self._is_saved_ssl_cert_available():
            try:
                await self._try_saving_ssl_cert_for_first_time(ca_sslc)
            except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
                raise ErrorGettingSSLCertFromServer(e) from e
        # now we have a file saved in our certificate store
        siz = os.stat(self.cert_path).st_size
        if siz == 0:
            # CA signed cert
            sslc = ca_sslc
        else:
            # pinned self-signed cert
            sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
            # note: Flag "ssl.VERIFY_X509_STRICT" is enabled by default in python 3.13+ (disabled in older versions).
            #       We explicitly disable it as it breaks lots of servers.
            sslc.verify_flags &= ~ssl.VERIFY_X509_STRICT
            sslc.check_hostname = False
        return sslc

    def handle_disconnect(func):
        @functools.wraps(func)
        async def wrapper_func(self: 'Interface', *args, **kwargs):
            try:
                return await func(self, *args, **kwargs)
            except GracefulDisconnect as e:
                self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
            except aiorpcx.jsonrpc.RPCError as e:
                self.logger.warning(f"disconnecting due to {repr(e)}")
                self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
            finally:
                self.got_disconnected.set()
                # Make sure taskgroup gets cleaned-up. This explicit clean-up is needed here
                # in case the "with taskgroup" ctx mgr never got a chance to run:
                await self.taskgroup.cancel_remaining()
                await self.network.connection_down(self)
                # if was not 'ready' yet, schedule waiting coroutines:
                self.ready.cancel()
        return wrapper_func

    @ignore_exceptions  # do not kill network.taskgroup
    @log_exceptions
    @handle_disconnect
    async def run(self):
        try:
            ssl_context = await self._get_ssl_context()
        except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
            self.logger.info(f'disconnecting due to: {repr(e)}')
            return
        try:
            await self.open_session(ssl_context=ssl_context)
        except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
            # make SSL errors for main interface more visible (to help server ops debug cert pinning issues)
            if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
                    and self.is_main_server() and not self.network.auto_connect):
                self.logger.warning(f'Cannot connect to main server due to SSL error '
                                    f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
            else:
                self.logger.info(f'disconnecting due to: {repr(e)}')
            return

    def _mark_ready(self) -> None:
        if self.ready.cancelled():
            raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
        if self.ready.done():
            return

        assert self.tip_header
        chain = blockchain.check_header(self.tip_header)
        if not chain:
            self.blockchain = blockchain.get_best_chain()
        else:
            self.blockchain = chain
        assert self.blockchain is not None

        self.logger.info(f"set blockchain with height {self.blockchain.height()}")

        self.ready.set_result(1)

    def is_connected_and_ready(self) -> bool:
        return self.ready.done() and not self.got_disconnected.is_set()

    async def _save_certificate(self) -> None:
        if not os.path.exists(self.cert_path):
            # we may need to retry this a few times, in case the handshake hasn't completed
            for _ in range(10):
                dercert = await self._fetch_certificate()
                if dercert:
                    self.logger.info("succeeded in getting cert")
                    self._verify_certificate_fingerprint(dercert)
                    with open(self.cert_path, 'w') as f:
                        cert = ssl.DER_cert_to_PEM_cert(dercert)
                        # workaround android bug
                        cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
                        f.write(cert)
                        # even though close() flushes, we can't fsync once the file is closed.
                        # and we must flush before fsyncing: flush writes to the OS buffer,
                        # fsync writes the OS buffer to disk
                        f.flush()
                        os.fsync(f.fileno())
                    break
                await asyncio.sleep(1)
            else:
                raise GracefulDisconnect("could not get certificate after 10 tries")

    async def _fetch_certificate(self) -> bytes:
        sslc = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
        sslc.check_hostname = False
        sslc.verify_mode = ssl.CERT_NONE
        async with _RSClient(
            session_factory=RPCSession,
            host=self.host, port=self.port,
            ssl=sslc,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        ) as session:
            asyncio_transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
            ssl_object = asyncio_transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
            return ssl_object.getpeercert(binary_form=True)

    def _get_expected_fingerprint(self) -> Optional[str]:
        if self.is_main_server():
            return self.network.config.NETWORK_SERVERFINGERPRINT
        return None

    def _verify_certificate_fingerprint(self, certificate: bytes) -> None:
        expected_fingerprint = self._get_expected_fingerprint()
        if not expected_fingerprint:
            return
        fingerprint = hashlib.sha256(certificate).hexdigest()
        fingerprints_match = fingerprint.lower() == expected_fingerprint.lower()
        if not fingerprints_match:
            util.trigger_callback('cert_mismatch')
            raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
        self.logger.info("cert fingerprint verification passed")

    async def _maybe_warm_headers_cache(self, *, from_height: int, to_height: int, mode: ChainResolutionMode) -> None:
        """Populate header cache for block heights in range [from_height, to_height]."""
        assert from_height <= to_height, (from_height, to_height)
        assert to_height - from_height < MAX_NUM_HEADERS_PER_REQUEST
        if all(height in self._headers_cache for height in range(from_height, to_height+1)):
            # cache already has all requested headers
            return
        # use lower timeout as we usually have network.bhi_lock here
        timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
        count = to_height - from_height + 1
        headers = await self.get_block_headers(start_height=from_height, count=count, timeout=timeout, mode=mode)
        for idx, raw_header in enumerate(headers):
            header_height = from_height + idx
            self._headers_cache[header_height] = raw_header

    async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        #self.logger.debug(f'get_block_header() {height} in {mode=}')
        # use lower timeout as we usually have network.bhi_lock here
        timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
        if raw_header := self._headers_cache.get(height):
            return blockchain.deserialize_header(raw_header, height)
        self.logger.info(f'requesting block header {height} in {mode=}')
        res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
        return blockchain.deserialize_header(bytes.fromhex(res), height)

    async def get_block_headers(
        self,
        *,
        start_height: int,
        count: int,
        timeout=None,
        mode: Optional[ChainResolutionMode] = None,
    ) -> Sequence[bytes]:
        """Request a number of consecutive block headers, starting at `start_height`.
        `count` is the num of requested headers, BUT note the server might return fewer than this
        (if range would extend beyond its tip).
        note: the returned headers are not verified or parsed at all.
        """
        if not is_non_negative_integer(start_height):
            raise Exception(f"{repr(start_height)} is not a block height")
        if not is_non_negative_integer(count) or not (0 < count <= MAX_NUM_HEADERS_PER_REQUEST):
            raise Exception(f"{repr(count)} not an int in range ]0, {MAX_NUM_HEADERS_PER_REQUEST}]")
        self.logger.info(
            f"requesting block headers: [{start_height}, {start_height+count-1}], {count=}"
            + (f" (in {mode=})" if mode is not None else "")
        )
        res = await self.session.send_request('blockchain.block.headers', [start_height, count], timeout=timeout)
        # check response
        assert_dict_contains_field(res, field_name='count')
        assert_dict_contains_field(res, field_name='hex')
        assert_dict_contains_field(res, field_name='max')
        assert_non_negative_integer(res['count'])
        assert_non_negative_integer(res['max'])
        assert_hex_str(res['hex'])
        if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
            raise RequestCorrupted('inconsistent chunk hex and count')
        # we never request more than MAX_NUM_HEADERS_PER_REQUEST headers, but we enforce those fit in a single response
        if res['max'] < MAX_NUM_HEADERS_PER_REQUEST:
            raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < {MAX_NUM_HEADERS_PER_REQUEST}")
        if res['count'] > count:
            raise RequestCorrupted(f"asked for {count} headers but got more: {res['count']}")
        elif res['count'] < count:
            # we only tolerate getting fewer headers if it is due to reaching the tip
            end_height = start_height + res['count'] - 1
            if end_height < self.tip:  # still below tip. why did server not send more?!
                raise RequestCorrupted(
                    f"asked for {count} headers but got fewer: {res['count']}. ({start_height=}, {self.tip=})")
        # checks done.
        headers = list(util.chunks(bfh(res['hex']), size=HEADER_SIZE))
        return headers

    async def request_chunk_below_max_checkpoint(
        self,
        *,
        height: int,
    ) -> None:
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        assert height <= constants.net.max_checkpoint(), f"{height=} must be <= cp={constants.net.max_checkpoint()}"
        index = height // CHUNK_SIZE
        if index in self._requested_chunks:
            return None
        self.logger.debug(f"requesting chunk from height {height}")
        try:
            self._requested_chunks.add(index)
            headers = await self.get_block_headers(start_height=index * CHUNK_SIZE, count=CHUNK_SIZE)
        finally:
            self._requested_chunks.discard(index)
        conn = self.blockchain.connect_chunk(index, data=b"".join(headers))
        if not conn:
            raise RequestCorrupted(f"chunk ({index=}, for {height=}) does not connect to blockchain")
        return None

    async def _fast_forward_chain(
        self,
        *,
        height: int,  # usually local chain tip + 1
        tip: int,  # server tip. we should not request past this.
    ) -> int:
        """Request some headers starting at `height` to grow the blockchain of this interface.
        Returns number of headers we managed to connect, starting at `height`.
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        if not is_non_negative_integer(tip):
            raise Exception(f"{repr(tip)} is not a block height")
        if not (height > constants.net.max_checkpoint()
                or height == 0 == constants.net.max_checkpoint()):
            raise Exception(f"{height=} must be > cp={constants.net.max_checkpoint()}")
        assert height <= tip, f"{height=} must be <= {tip=}"
        # Request a few chunks of headers concurrently.
        # tradeoffs:
        # - more chunks: higher memory requirements
        # - more chunks: higher concurrency => syncing needs fewer network round-trips
        # - if a chunk does not connect, bandwidth for all later chunks is wasted
        async with OldTaskGroup() as group:
            tasks = []  # type: List[Tuple[int, asyncio.Task[Sequence[bytes]]]]
            index0 = height // CHUNK_SIZE
            for chunk_cnt in range(10):
                index = index0 + chunk_cnt
                start_height = index * CHUNK_SIZE
                if start_height > tip:
                    break
                end_height = min(start_height + CHUNK_SIZE - 1, tip)
                size = end_height - start_height + 1
                tasks.append((index, await group.spawn(self.get_block_headers(start_height=start_height, count=size))))
        # try to connect chunks
        num_headers = 0
        for index, task in tasks:
            headers = task.result()
            conn = self.blockchain.connect_chunk(index, data=b"".join(headers))
            if not conn:
                break
            num_headers += len(headers)
        # We started at a chunk boundary, instead of requested `height`. Need to correct for that.
        offset = height - index0 * CHUNK_SIZE
        return max(0, num_headers - offset)

    def is_main_server(self) -> bool:
        return (self.network.interface == self or
                self.network.interface is None and self.network.default_server == self.server)

    async def open_session(
        self,
        *,
        ssl_context: Optional[ssl.SSLContext],
        exit_early: bool = False,
    ):
        session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
        async with _RSClient(
            session_factory=session_factory,
            host=self.host, port=self.port,
            ssl=ssl_context,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        ) as session:
            self.session = session  # type: NotificationSession
            self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
            try:
                ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
            except aiorpcx.jsonrpc.RPCError as e:
                raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
            if exit_early:
                return
            if ver[1] != version.PROTOCOL_VERSION:
                raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                         f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
            if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
                raise GracefulDisconnect(f'too many connected servers already '
                                         f'in bucket {self.bucket_based_on_ipaddress()}')
            self.logger.info(f"connection established. version: {ver}")

            try:
                async with self.taskgroup as group:
                    await group.spawn(self.ping)
                    await group.spawn(self.request_fee_estimates)
                    await group.spawn(self.run_fetch_blocks)
                    await group.spawn(self.monitor_connection)
            except aiorpcx.jsonrpc.RPCError as e:
                if e.code in (
                    JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                    JSONRPC.SERVER_BUSY,
                    JSONRPC.METHOD_NOT_FOUND,
                    JSONRPC.INTERNAL_ERROR,
                ):
                    log_level = logging.WARNING if self.is_main_server() else logging.INFO
                    raise GracefulDisconnect(e, log_level=log_level) from e
                raise
            finally:
                self.got_disconnected.set()  # set this ASAP, ideally before any awaits

    async def monitor_connection(self):
        while True:
            await asyncio.sleep(1)
            # If the session/transport is no longer open, we disconnect.
            # e.g. if the remote cleanly sends EOF, we would handle that here.
            # note: If the user pulls the ethernet cable or disconnects wifi,
            #       ideally we would detect that here, so that the GUI/etc can reflect that.
            #       - On Android, this seems to work reliably, where asyncio.BaseProtocol.connection_lost()
            #         gets called with e.g. ConnectionAbortedError(103, 'Software caused connection abort').
            #       - On desktop Linux/Win, it seems BaseProtocol.connection_lost() is not called in such cases.
            #         Hence, in practice the connection issue will only be detected the next time we try
            #         to send a message (plus timeout), which can take minutes...
            if not self.session or self.session.is_closing():
                raise GracefulDisconnect('session was closed')

    async def ping(self):
        # We periodically send a "ping" msg to make sure the server knows we are still here.
        # Adding a bit of randomness generates some noise against traffic analysis.
        while True:
            await asyncio.sleep(random.random() * 300)
            await self.session.send_request('server.ping')
            await self._maybe_send_noise()

    async def _maybe_send_noise(self):
        while random.random() < 0.2:
            await asyncio.sleep(random.random())
            await self.session.send_request('server.ping')
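
    # The loop above sends a geometrically distributed number of extra pings:
    # P(k extra pings) = 0.8 * 0.2**k, so on average 0.2 / (1 - 0.2) = 0.25
    # extra pings (each preceded by up to ~1s of sleep) per call.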

    async def request_fee_estimates(self):
        while True:
            async with OldTaskGroup() as group:
                fee_tasks = []
                for i in FEE_ETA_TARGETS[0:-1]:
                    fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
            for nblock_target, task in fee_tasks:
                fee = task.result()
                if fee < 0: continue
                assert isinstance(fee, int)
                self.fee_estimates_eta[nblock_target] = fee
            self.network.update_fee_estimates()
            await asyncio.sleep(60)

    async def close(self, *, force_after: int = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if self.session:
            await self.session.close(force_after=force_after)
        # monitor_connection will cancel tasks

    async def run_fetch_blocks(self):
        header_queue = asyncio.Queue()
        await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
        while True:
            item = await header_queue.get()
            raw_header = item[0]
            height = raw_header['height']
            header_bytes = bfh(raw_header['hex'])
            header_dict = blockchain.deserialize_header(header_bytes, height)
            self.tip_header = header_dict
            self.tip = height
            if self.tip < constants.net.max_checkpoint():
                raise GracefulDisconnect(
                    f"server tip below max checkpoint. ({self.tip} < {constants.net.max_checkpoint()})")
            self._mark_ready()
            self._headers_cache.clear()  # tip changed, so assume anything could have happened with chain
            self._headers_cache[height] = header_bytes
            try:
                blockchain_updated = await self._process_header_at_tip()
            finally:
                self._headers_cache.clear()  # to reduce memory usage
            # header processing done
            if self.is_main_server() or blockchain_updated:
                self.logger.info(f"new chain tip. {height=}")
            if blockchain_updated:
                util.trigger_callback('blockchain_updated')
            util.trigger_callback('network_updated')
            await self.network.switch_unwanted_fork_interface()
            await self.network.switch_lagging_interface()
            await self.taskgroup.spawn(self._maybe_send_noise())

    async def _process_header_at_tip(self) -> bool:
        """Returns:
        False - boring fast-forward: we already have this header as part of this blockchain from another interface,
        True - new header we didn't have, or reorg
        """
        height, header = self.tip, self.tip_header
        async with self.network.bhi_lock:
            if self.blockchain.height() >= height and self.blockchain.check_header(header):
                # another interface amended the blockchain
                return False
            await self.sync_until(height)
            return True

    async def sync_until(
        self,
        height: int,
        *,
        next_height: Optional[int] = None,  # sync target. typically the tip, except in unit tests
    ) -> Tuple[ChainResolutionMode, int]:
        if next_height is None:
            next_height = self.tip
        last = None  # type: Optional[ChainResolutionMode]
        while last is None or height <= next_height:
            prev_last, prev_height = last, height
            if next_height > height + 144:
                # We are far from the tip.
                # It is more efficient to process headers in large batches (CPU/disk_usage/logging).
                # (but this wastes a little bandwidth, if we are not on a chunk boundary)
                num_headers = await self._fast_forward_chain(
                    height=height, tip=next_height)
                if num_headers == 0:
                    if height <= constants.net.max_checkpoint():
                        raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                    last, height = await self.step(height)
                    continue
                # report progress to gui/etc
                util.trigger_callback('blockchain_updated')
                util.trigger_callback('network_updated')
                height += num_headers
                assert height <= next_height+1, (height, self.tip)
                last = ChainResolutionMode.CATCHUP
            else:
                # We are close to the tip, so process headers one-by-one.
                # (note: due to headers_cache, to save network latency, this can still batch-request headers)
                last, height = await self.step(height)
            assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
        return last, height

    async def step(
        self,
        height: int,
    ) -> Tuple[ChainResolutionMode, int]:
        assert 0 <= height <= self.tip, (height, self.tip)
        await self._maybe_warm_headers_cache(
            from_height=height,
            to_height=min(self.tip, height+MAX_NUM_HEADERS_PER_REQUEST-1),
            mode=ChainResolutionMode.CATCHUP,
        )
        header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP)

        chain = blockchain.check_header(header)
        if chain:
            self.blockchain = chain
            # note: there is an edge case here that is not handled.
            # we might know the blockhash (enough for check_header) but
            # not have the header itself. e.g. regtest chain with only genesis.
            # this situation resolves itself on the next block
            return ChainResolutionMode.CATCHUP, height+1

        can_connect = blockchain.can_connect(header)
        if not can_connect:
            self.logger.info(f"can't connect new block: {height=}")
            height, header, bad, bad_header = await self._search_headers_backwards(height, header=header)
            chain = blockchain.check_header(header)
            can_connect = blockchain.can_connect(header)
            assert chain or can_connect
        if can_connect:
            height += 1
            self.blockchain = can_connect
            self.blockchain.save_header(header)
            return ChainResolutionMode.CATCHUP, height

        good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
        return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)

    async def _search_headers_binary(
        self,
        height: int,
        bad: int,
        bad_header: dict,
        chain: Optional[Blockchain],
    ) -> Tuple[int, int, dict]:
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.blockchain = chain
        good = height
        while True:
            assert 0 <= good < bad, (good, bad)
            height = (good + bad) // 2
            self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
            if bad - good + 1 <= MAX_NUM_HEADERS_PER_REQUEST:  # if interval is small, trade some bandwidth for lower latency
                await self._maybe_warm_headers_cache(
                    from_height=good, to_height=bad, mode=ChainResolutionMode.BINARY)
            header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY)
            chain = blockchain.check_header(header)
            if chain:
                self.blockchain = chain
                good = height
            else:
                bad = height
                bad_header = header
            if good + 1 == bad:
                break

        if not self.blockchain.can_connect(bad_header, check_height=False):
            raise Exception('unexpected bad header during binary: {}'.format(bad_header))
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.logger.info(f"binary search exited. good {good}, bad {bad}. {chain=}")
        return good, bad, bad_header
1157
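
    # Editor's sketch (hypothetical helper, not part of Electrum): the loop
    # above is a textbook predicate bisection. Given a pure oracle
    # is_on_chain(h) that is True up to the last shared height and False
    # afterwards, the same invariant -- `good` always passes, `bad` always
    # fails -- halves the interval until good + 1 == bad, e.g.
    # _sketch_bisect_forkpoint(0, 100, lambda h: h < 37) == (36, 37).
    @staticmethod
    def _sketch_bisect_forkpoint(good: int, bad: int, is_on_chain) -> Tuple[int, int]:
        assert is_on_chain(good) and not is_on_chain(bad)
        while good + 1 < bad:
            mid = (good + bad) // 2
            if is_on_chain(mid):
                good = mid
            else:
                bad = mid
        return good, bad  # bad == first height on the server's fork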

    async def _resolve_potential_chain_fork_given_forkpoint(
        self,
        good: int,
        bad: int,
        bad_header: dict,
    ) -> Tuple[ChainResolutionMode, int]:
        assert good + 1 == bad
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)
        # 'good' is the height of a block 'good_header', somewhere in self.blockchain.
        # bad_header connects to good_header; bad_header itself is NOT in self.blockchain.

        bh = self.blockchain.height()
        assert bh >= good, (bh, good)
        if bh == good:
            height = good + 1
            self.logger.info(f"catching up from {height}")
            return ChainResolutionMode.NO_FORK, height

        # this is a new fork we don't yet have
        height = bad + 1
        self.logger.info(f"new fork at bad height {bad}")
        b = self.blockchain.fork(bad_header)  # type: Blockchain
        self.blockchain = b
        assert b.forkpoint == bad
        return ChainResolutionMode.FORK, height

    async def _search_headers_backwards(
        self,
        height: int,
        *,
        header: dict,
    ) -> Tuple[int, dict, int, dict]:
        async def iterate():
            nonlocal height, header
            checkp = False
            if height <= constants.net.max_checkpoint():
                height = constants.net.max_checkpoint()
                checkp = True
            header = await self.get_block_header(height, mode=ChainResolutionMode.BACKWARD)
            chain = blockchain.check_header(header)
            can_connect = blockchain.can_connect(header)
            if chain or can_connect:
                return False
            if checkp:
                raise GracefulDisconnect("server chain conflicts with checkpoints")
            return True

        bad, bad_header = height, header
        _assert_header_does_not_check_against_any_chain(bad_header)
        with blockchain.blockchains_lock:
            chains = list(blockchain.blockchains.values())
        local_max = max([0] + [x.height() for x in chains])
        height = min(local_max + 1, height - 1)
        assert height >= 0

        await self._maybe_warm_headers_cache(
            from_height=max(0, height-10), to_height=height, mode=ChainResolutionMode.BACKWARD)

        delta = 2
        while await iterate():
            bad, bad_header = height, header
            height -= delta
            delta *= 2

        _assert_header_does_not_check_against_any_chain(bad_header)
        self.logger.info(f"exiting backward mode at {height}")
        return height, header, bad, bad_header
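
    # Editor's sketch (hypothetical helper, not part of Electrum): the loop
    # above is a "galloping" search -- the step size doubles on every failed
    # probe, so a header shared with a local chain is found after O(log gap)
    # requests instead of O(gap). A pure-function version of the same idea,
    # assuming connects(0) is always True (genesis):
    @staticmethod
    def _sketch_gallop_backwards(height: int, connects) -> int:
        delta = 2
        while height > 0 and not connects(height):
            height = max(0, height - delta)
            delta *= 2
        return height  # some height that connects; bisection then refines it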

1226
    @classmethod
5✔
1227
    def client_name(cls) -> str:
5✔
1228
        return f'electrum/{version.ELECTRUM_VERSION}'
×
1229

1230
    def is_tor(self):
5✔
1231
        return self.host.endswith('.onion')
×
1232

1233
    def ip_addr(self) -> Optional[str]:
5✔
1234
        session = self.session
×
1235
        if not session: return None
×
1236
        peer_addr = session.remote_address()
×
1237
        if not peer_addr: return None
×
1238
        return str(peer_addr.host)
×
1239

    def bucket_based_on_ipaddress(self) -> str:
        def do_bucket():
            if self.is_tor():
                return BUCKET_NAME_OF_ONION_SERVERS
            try:
                ip_addr = ip_address(self.ip_addr())  # type: Union[IPv4Address, IPv6Address]
            except ValueError:
                return ''
            if not ip_addr:
                return ''
            if ip_addr.is_loopback:  # localhost is exempt
                return ''
            if ip_addr.version == 4:
                slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
                return str(slash16)
            elif ip_addr.version == 6:
                slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
                return str(slash48)
            return ''

        if not self._ipaddr_bucket:
            self._ipaddr_bucket = do_bucket()
        return self._ipaddr_bucket
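
    # Editor's note: /16 (IPv4) and /48 (IPv6) roughly match common allocation
    # sizes, so bucketing by them limits how many "distinct" identities a
    # single operator can present. Worked example (stdlib `ipaddress`):
    #
    #     >>> str(IPv4Network('203.0.113.5').supernet(prefixlen_diff=16))
    #     '203.0.0.0/16'
    #     >>> str(IPv6Network('2001:db8::1').supernet(prefixlen_diff=80))
    #     '2001:db8::/48'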

    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        # do request
        res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
        # check response
        block_height = assert_dict_contains_field(res, field_name='block_height')
        merkle = assert_dict_contains_field(res, field_name='merkle')
        pos = assert_dict_contains_field(res, field_name='pos')
        # note: tx_height was just a hint to the server, don't enforce the response to match it
        assert_non_negative_integer(block_height)
        assert_non_negative_integer(pos)
        assert_list_or_tuple(merkle)
        for item in merkle:
            assert_hash256_str(item)
        return res
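
    # Editor's sketch (hypothetical helper, not part of this class): how a
    # client can fold the returned `merkle` branch back into a merkle root,
    # to be compared against the root in the header at `block_height`.
    # Protocol hex hashes are byte-reversed relative to the raw digests;
    # each level is double-SHA256, with the bits of `pos` choosing the
    # left/right concatenation order.
    @staticmethod
    def _sketch_merkle_root_from_branch(tx_hash: str, merkle: Sequence[str], pos: int) -> str:
        def sha256d(x: bytes) -> bytes:
            return hashlib.sha256(hashlib.sha256(x).digest()).digest()
        h = bytes.fromhex(tx_hash)[::-1]
        for sibling_hex in merkle:
            sibling = bytes.fromhex(sibling_hex)[::-1]
            h = sha256d(sibling + h) if pos & 1 else sha256d(h + sibling)
            pos >>= 1
        return h[::-1].hex()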

    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
        # validate response
        if not is_hex_str(raw):
            raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw!r}")
        tx = Transaction(raw)
        try:
            tx.deserialize()  # see if raises
        except Exception as e:
            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
        if tx.txid() != tx_hash:
            raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
        return raw

    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
        # check response
        assert_list_or_tuple(res)
        prev_height = 1
        for tx_item in res:
            height = assert_dict_contains_field(tx_item, field_name='height')
            assert_dict_contains_field(tx_item, field_name='tx_hash')
            assert_integer(height)
            assert_hash256_str(tx_item['tx_hash'])
            if height in (-1, 0):
                assert_dict_contains_field(tx_item, field_name='fee')
                assert_non_negative_integer(tx_item['fee'])
                prev_height = float("inf")  # this ensures confirmed txs can't follow mempool txs
            else:
                # check monotonicity of heights
                if height < prev_height:
                    raise RequestCorrupted('heights of confirmed txs must be in increasing order')
                prev_height = height
        hashes = set(map(lambda item: item['tx_hash'], res))
        if len(hashes) != len(res):
            # Either the server is sending garbage... or maybe, if the server is race-prone,
            # a recently mined tx could be included in both the last block and the mempool?
            # Still, it's simplest to just disregard the response.
            raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
        return res
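
    # Editor's sketch (hypothetical helper): the `sh` argument here and in the
    # methods below is the Electrum-protocol scripthash -- the SHA256 of the
    # output script (scriptPubKey), hex-encoded in reversed byte order:
    @staticmethod
    def _sketch_scripthash(script_pubkey: bytes) -> str:
        return hashlib.sha256(script_pubkey).digest()[::-1].hex()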

    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
        # check response
        assert_list_or_tuple(res)
        for utxo_item in res:
            assert_dict_contains_field(utxo_item, field_name='tx_pos')
            assert_dict_contains_field(utxo_item, field_name='value')
            assert_dict_contains_field(utxo_item, field_name='tx_hash')
            assert_dict_contains_field(utxo_item, field_name='height')
            assert_non_negative_integer(utxo_item['tx_pos'])
            assert_non_negative_integer(utxo_item['value'])
            assert_non_negative_integer(utxo_item['height'])
            assert_hash256_str(utxo_item['tx_hash'])
        return res

    async def get_balance_for_scripthash(self, sh: str) -> dict:
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
        # check response
        assert_dict_contains_field(res, field_name='confirmed')
        assert_dict_contains_field(res, field_name='unconfirmed')
        assert_non_negative_integer(res['confirmed'])
        assert_integer(res['unconfirmed'])
        return res

    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        if not is_non_negative_integer(tx_pos):
            raise Exception(f"{repr(tx_pos)} should be a non-negative integer")
        # do request
        res = await self.session.send_request(
            'blockchain.transaction.id_from_pos',
            [tx_height, tx_pos, merkle],
        )
        # check response
        if merkle:
            assert_dict_contains_field(res, field_name='tx_hash')
            assert_dict_contains_field(res, field_name='merkle')
            assert_hash256_str(res['tx_hash'])
            assert_list_or_tuple(res['merkle'])
            for node_hash in res['merkle']:
                assert_hash256_str(node_hash)
        else:
            assert_hash256_str(res)
        return res

    async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
        # do request
        res = await self.session.send_request('mempool.get_fee_histogram')
        # check response
        assert_list_or_tuple(res)
        prev_fee = float('inf')
        for fee, s in res:
            assert_non_negative_int_or_float(fee)
            assert_non_negative_integer(s)
            if fee >= prev_fee:  # check monotonicity
                raise RequestCorrupted('fees must be in decreasing order')
            prev_fee = fee
        return res
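
    # Editor's sketch (hypothetical helper): one common way to consume the
    # histogram validated above. Entries are (feerate in sat/vbyte, vsize in
    # vbytes) sorted by decreasing feerate, so accumulating vsizes yields the
    # approximate feerate needed to sit within the first `depth_vbytes` of
    # the mempool:
    @staticmethod
    def _sketch_feerate_at_depth(histogram, depth_vbytes: int):
        cumulative = 0
        for feerate, vsize in histogram:
            cumulative += vsize
            if cumulative >= depth_vbytes:
                return feerate
        return 0  # mempool is smaller than the requested depth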

    async def get_server_banner(self) -> str:
        # do request
        res = await self.session.send_request('server.banner')
        # check response
        if not isinstance(res, str):
            raise RequestCorrupted(f'{res!r} should be a str')
        return res

    async def get_donation_address(self) -> str:
        # do request
        res = await self.session.send_request('server.donation_address')
        # check response
        if not res:  # ignore empty string
            return ''
        if not bitcoin.is_address(res):
            # note: do not hard-fail -- allow the server to use a future-type
            #       bitcoin address we do not recognize
            self.logger.info(f"invalid donation address from server: {repr(res)}")
            res = ''
        return res

    async def get_relay_fee(self) -> int:
        """Returns the min relay feerate in sat/kbyte."""
        # do request
        res = await self.session.send_request('blockchain.relayfee')
        # check response
        assert_non_negative_int_or_float(res)
        relayfee = int(res * bitcoin.COIN)
        relayfee = max(0, relayfee)
        return relayfee

    async def get_estimatefee(self, num_blocks: int) -> int:
        """Returns a feerate estimate for getting confirmed within
        num_blocks blocks, in sat/kbyte.
        Returns -1 if the server could not provide an estimate.
        """
        if not is_non_negative_integer(num_blocks):
            raise Exception(f"{repr(num_blocks)} is not a num_blocks")
        # do request
        try:
            res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
        except aiorpcx.jsonrpc.ProtocolError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate, however apparently "electrs" does not conform
            # and sends an error instead. Convert it here:
            if "cannot estimate fee" in e.message:
                res = -1
            else:
                raise
        except aiorpcx.jsonrpc.RPCError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate. "Fulcrum" often sends:
            #   aiorpcx.jsonrpc.RPCError: (-32603, 'internal error: bitcoind request timed out')
            if e.code == JSONRPC.INTERNAL_ERROR:
                res = -1
            else:
                raise
        # check response
        if res != -1:
            assert_non_negative_int_or_float(res)
            res = int(res * bitcoin.COIN)
        return res

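# Editor's note: `blockchain.relayfee` and `blockchain.estimatefee` both reply
# in BTC/kbyte; the two methods above scale that to sat/kbyte via
# `bitcoin.COIN` (100_000_000 sat per BTC). Worked example:
#
#     >>> int(0.00012 * 100_000_000)  # server replies 0.00012 BTC/kB
#     12000                           # -> 12000 sat/kB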

def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
    chain_bad = blockchain.check_header(header)
    if chain_bad:
        raise Exception('bad_header must not check!')


def check_cert(host, cert):
    try:
        b = pem.dePem(cert, 'CERTIFICATE')
        x = x509.X509(b)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        return

    try:
        x.check_date()
        expired = False
    except Exception:
        expired = True

    m = "host: %s\n" % host
    m += "has_expired: %s\n" % expired
    util.print_msg(m)


# Used by tests
def _match_hostname(name, val):
    if val == name:
        return True

    return val.startswith('*.') and name.endswith(val[1:])

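# Editor's note: `_match_hostname` accepts an exact match, or a `*.` wildcard
# matched by suffix; the wildcard keeps its dot, so the bare domain itself
# does not match:
#
#     >>> _match_hostname('foo.example.com', '*.example.com')
#     True
#     >>> _match_hostname('example.com', '*.example.com')
#     False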

def test_certificates():
    from .simple_config import SimpleConfig
    config = SimpleConfig()
    mydir = os.path.join(config.path, "certs")
    certs = os.listdir(mydir)
    for c in certs:
        p = os.path.join(mydir, c)
        with open(p, encoding='utf-8') as f:
            cert = f.read()
        check_cert(c, cert)


if __name__ == "__main__":
    test_certificates()