• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

spesmilo / electrum / 4993374941347840

06 Jun 2025 04:42PM UTC coverage: 59.782% (+0.004%) from 59.778%
4993374941347840

push

CirrusCI

SomberNight
interface: small clean-up. intro ChainResolutionMode.

- type hints
- minor API changes
- no functional changes

29 of 42 new or added lines in 4 files covered. (69.05%)

2 existing lines in 2 files now uncovered.

21901 of 36635 relevant lines covered (59.78%)

2.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

35.86
/electrum/interface.py
1
#!/usr/bin/env python
2
#
3
# Electrum - lightweight Bitcoin client
4
# Copyright (C) 2011 thomasv@gitorious
5
#
6
# Permission is hereby granted, free of charge, to any person
7
# obtaining a copy of this software and associated documentation files
8
# (the "Software"), to deal in the Software without restriction,
9
# including without limitation the rights to use, copy, modify, merge,
10
# publish, distribute, sublicense, and/or sell copies of the Software,
11
# and to permit persons to whom the Software is furnished to do so,
12
# subject to the following conditions:
13
#
14
# The above copyright notice and this permission notice shall be
15
# included in all copies or substantial portions of the Software.
16
#
17
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25
import os
5✔
26
import re
5✔
27
import ssl
5✔
28
import sys
5✔
29
import time
5✔
30
import traceback
5✔
31
import asyncio
5✔
32
import socket
5✔
33
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
5✔
34
from collections import defaultdict
5✔
35
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
5✔
36
import itertools
5✔
37
import logging
5✔
38
import hashlib
5✔
39
import functools
5✔
40
import random
5✔
41
import enum
5✔
42

43
import aiorpcx
5✔
44
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
5✔
45
from aiorpcx.curio import timeout_after, TaskTimeout
5✔
46
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
5✔
47
from aiorpcx.rawsocket import RSClient, RSTransport
5✔
48
import certifi
5✔
49

50
from .util import (ignore_exceptions, log_exceptions, bfh, ESocksProxy,
5✔
51
                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
52
                   is_int_or_float, is_non_negative_int_or_float, OldTaskGroup)
53
from . import util
5✔
54
from . import x509
5✔
55
from . import pem
5✔
56
from . import version
5✔
57
from . import blockchain
5✔
58
from .blockchain import Blockchain, HEADER_SIZE
5✔
59
from . import bitcoin
5✔
60
from . import constants
5✔
61
from .i18n import _
5✔
62
from .logging import Logger
5✔
63
from .transaction import Transaction
5✔
64
from .fee_policy import FEE_ETA_TARGETS
5✔
65

66
if TYPE_CHECKING:
5✔
67
    from .network import Network
×
68
    from .simple_config import SimpleConfig
×
69

70

71
# Path of the CA certificate bundle shipped with certifi; used as the trust
# store when building the CA-enforcing SSL context.
ca_path = certifi.where()

# NOTE(review): presumably the bucket label under which .onion servers are
# grouped; its use is outside this chunk -- confirm against callers.
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# Protocol codes accepted by ServerAddr. 's' selects SSL/TLS; any other
# protocol is treated as plaintext TCP by Interface._get_ssl_context.
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
# Protocol assumed by ServerAddr.from_str_with_inference when none is given.
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
78

79

80
class NetworkTimeout:
    """Namespace of timeout presets for network operations (values in seconds)."""
    # seconds
    class Generic:
        # Baseline profile.
        NORMAL = 30
        RELAXED = 45
        MOST_RELAXED = 600

    class Urgent(Generic):
        # Same three presets as Generic, each with a shorter value.
        NORMAL = 10
        RELAXED = 20
        MOST_RELAXED = 60
91

92

93
def assert_non_negative_integer(val: Any) -> None:
    """Validate that *val* is a non-negative integer.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_non_negative_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative integer')
96

97

98
def assert_integer(val: Any) -> None:
    """Validate that *val* is an integer.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be an integer')
101

102

103
def assert_int_or_float(val: Any) -> None:
    """Validate that *val* is an int or a float.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be int or float')
106

107

108
def assert_non_negative_int_or_float(val: Any) -> None:
    """Validate that *val* is a non-negative int or float.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_non_negative_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
111

112

113
def assert_hash256_str(val: Any) -> None:
    """Validate that *val* passes ``is_hash256_str``.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_hash256_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hash256 str')
116

117

118
def assert_hex_str(val: Any) -> None:
    """Validate that *val* passes ``is_hex_str``.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if is_hex_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hex str')
121

122

123
def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Return ``d[field_name]`` after checking that *d* is a dict holding that key.

    Raises RequestCorrupted if *d* is not a dict, or if *field_name* is missing.
    """
    if isinstance(d, dict):
        if field_name in d:
            return d[field_name]
        raise RequestCorrupted(f'required field {field_name!r} missing from dict')
    raise RequestCorrupted(f'{d!r} should be a dict')
129

130

131
def assert_list_or_tuple(val: Any) -> None:
    """Validate that *val* is a list or a tuple.

    Raises RequestCorrupted if the check fails; returns None otherwise.
    """
    if isinstance(val, (list, tuple)):
        return
    raise RequestCorrupted(f'{val!r} should be a list or tuple')
134

135

136
class ChainResolutionMode(enum.Enum):
    """Closed set of chain-resolution modes.

    NOTE(review): the meaning of each member is defined by the header-sync
    code that consumes this enum, which is outside this chunk.
    """
    # Values are assigned by enum.auto(); only member identity is meaningful.
    CATCHUP = enum.auto()
    BACKWARD = enum.auto()
    BINARY = enum.auto()
    FORK = enum.auto()
    NO_FORK = enum.auto()
142

143

144
class NotificationSession(RPCSession):
    """RPCSession that routes server notifications to subscriber queues.

    Keeps, per (method, params) key, the list of subscribed queues and the
    most recent result, and logs traffic when debugging is enabled.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        """Create the session; *interface* is the owning Interface."""
        super(NotificationSession, self).__init__(*args, **kwargs)
        self.subscriptions = defaultdict(list)  # key -> list of subscriber queues
        self.cache = {}  # key -> latest result seen for that key
        self._msg_counter = itertools.count(start=1)  # local ids, used in log lines only
        self.interface = interface
        self.taskgroup = interface.taskgroup
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        """Handle an incoming message from the server.

        Only notifications for a key that has been subscribed to are accepted:
        the cached result is updated and ``request.args`` is put on every
        subscribed queue. Anything else is logged and the session is closed.
        """
        self.maybe_log(f"--> {request}")
        try:
            if isinstance(request, Notification):
                # the last arg is the new result, the preceding args identify the subscription
                params, result = request.args[:-1], request.args[-1]
                key = self.get_hashable_key_for_rpc_call(request.method, params)
                if key in self.subscriptions:
                    self.cache[key] = result
                    for queue in self.subscriptions[key]:
                        await queue.put(request.args)
                else:
                    raise Exception(f'unexpected notification')
            else:
                raise Exception(f'unexpected request. not a notification')
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        """Send a request via the base class and return the response.

        Raises RequestTimedOut if the request times out. Every other exception
        is logged (when debug logging is on) and re-raised unchanged.
        """
        # note: semaphores/timeouts/backpressure etc are handled by
        # aiorpcx. the timeout arg here in most cases should not be set
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # note: RPCSession.send_request raises TaskTimeout in case of a timeout.
            # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
            response = await util.wait_for2(
                super().send_request(*args, **kwargs),
                timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            self.maybe_log(f"--> request timed out: {args} (id: {msg_id})")
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        except BaseException as e:  # cancellations, etc. are useful for debugging
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        else:
            self.maybe_log(f"--> {response} (id: {msg_id})")
            return response

    def set_default_timeout(self, timeout):
        """Set both base-class timeout attributes to *timeout*."""
        assert hasattr(self, "sent_request_timeout")  # in base class
        self.sent_request_timeout = timeout
        assert hasattr(self, "max_send_delay")        # in base class
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        """Register *queue* for (method, params) and put the current result on it.

        The result comes from the cache if present, otherwise from a request
        to the server (whose response is then cached).
        """
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key in self.cache:
            result = self.cache[key]
        else:
            result = await self.send_request(method, params)
            self.cache[key] = result
        # same shape as what handle_request forwards: params followed by the result
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for v in self.subscriptions.values():
            if queue in v:
                v.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return str(method) + repr(params)

    def maybe_log(self, msg: str) -> None:
        """Log *msg* at debug level if debugging is on for the interface or the network."""
        if not self.interface: return
        if self.interface.debug or self.interface.network.debug:
            self.interface.logger.debug(msg)

    def default_framer(self):
        """Return a NewlineFramer whose max message size comes from the network config."""
        # overridden so that max_size can be customized
        max_size = self.interface.network.config.NETWORK_MAX_INCOMING_MSG_SIZE
        assert max_size > 500_000, f"{max_size=} (< 500_000) is too small"
        return NewlineFramer(max_size=max_size)

    async def close(self, *, force_after: Optional[int] = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if force_after is None:
            # We give up after a while and just abort the connection.
            # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
            #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
            # Note: if the ethernet cable was pulled or wifi disconnected, that too might
            #       wait until this timeout is triggered
            force_after = 1  # seconds
        await super().close(force_after=force_after)
251

252

253
class NetworkException(Exception):
    """Base class for network-related errors raised in this module."""


class GracefulDisconnect(NetworkException):
    """Signals that the connection should be dropped, carrying a logging level.

    ``log_level`` defaults to the class attribute (INFO) and can be overridden
    per instance through the keyword argument of the same name.
    """
    log_level = logging.INFO

    def __init__(self, *args, log_level=None, **kwargs):
        super().__init__(*args, **kwargs)
        if log_level is None:
            # keep the class-level default
            return
        self.log_level = log_level
263

264

265
class RequestTimedOut(GracefulDisconnect):
    """Raised when a network request times out."""

    def __str__(self):
        # Fixed, translatable message; the constructor args are not included.
        return _("Network request timed out.")
268

269

270
# Raised by the assert_* validators in this module when a value has an unexpected type or shape.
class RequestCorrupted(Exception): pass

# Raised when an already saved certificate file cannot be parsed.
class ErrorParsingSSLCert(Exception): pass
# Raised when obtaining the server certificate for the first time fails.
class ErrorGettingSSLCertFromServer(Exception): pass
# NOTE(review): presumably raised when the server cert does not match the expected fingerprint; raise site is outside this chunk.
class ErrorSSLCertFingerprintMismatch(Exception): pass
# Raised when an expected fingerprint is configured for a CA-signed server.
class InvalidOptionCombination(Exception): pass
# Raised by _RSClient.create_connection in place of OSError.
class ConnectError(NetworkException): pass
277

278

279
class _RSClient(RSClient):
    """RSClient that reports connection failures as ConnectError."""

    async def create_connection(self):
        """Create the connection via the base class; re-raise any OSError as ConnectError."""
        try:
            return await super().create_connection()
        except OSError as e:
            # note: using "from e" here will set __cause__ of ConnectError
            raise ConnectError(e) from e
286

287

288
class PaddedRSTransport(RSTransport):
    """A raw socket transport that provides basic countermeasures against traffic analysis
    by padding the jsonrpc payload with whitespaces to have ~uniform-size TCP packets.
    (it is assumed that a network observer does not see plaintext transport contents,
    due to it being wrapped e.g. in TLS)
    """

    # smallest size (in bytes) of a padded write
    MIN_PACKET_SIZE = 1024
    # how long buffered data may wait for more data before it is sent anyway
    WAIT_FOR_BUFFER_GROWTH_SECONDS = 1.0

    session: Optional['RPCSession']

    def __init__(self, *args, **kwargs):
        """Initialise the base transport and the (empty) send buffer state."""
        RSTransport.__init__(self, *args, **kwargs)
        self._sbuffer = bytearray()  # "send buffer"
        self._sbuffer_task = None  # type: Optional[asyncio.Task]
        self._sbuffer_has_data_evt = asyncio.Event()  # set while _sbuffer is non-empty
        self._last_send = time.monotonic()
        self._force_send = False  # type: bool

    # note: this does not call super().write() but is a complete reimplementation
    async def write(self, message):
        """Frame *message*, append it to the send buffer, and maybe flush the buffer.

        Nothing is buffered if the transport is already closing.
        """
        await self._can_send.wait()
        if self.is_closing():
            return
        framed_message = self._framer.frame(message)
        self._sbuffer += framed_message
        self._sbuffer_has_data_evt.set()
        self._maybe_consume_sbuffer()

    def _maybe_consume_sbuffer(self) -> None:
        """Maybe take some data from sbuffer and send it on the wire."""
        if not self._can_send.is_set() or self.is_closing():
            return
        buf = self._sbuffer
        if not buf:
            return
        # if there is enough data in the buffer, or if we haven't sent in a while, send now:
        if not (
            self._force_send
            or len(buf) >= self.MIN_PACKET_SIZE
            or self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS < time.monotonic()
        ):
            return
        assert buf[-2:] in (b"}\n", b"]\n"), f"unexpected json-rpc terminator: {buf[-2:]=!r}"
        # either (1) pad length to next power of two, to create "lsize" packet:
        # (2**bit_length is strictly greater than the payload length, so a payload
        #  that is already an exact power of two is padded up to twice its size)
        payload_lsize = len(buf)
        total_lsize = max(self.MIN_PACKET_SIZE, 2 ** (payload_lsize.bit_length()))
        npad_lsize = total_lsize - payload_lsize
        # or if that wasted a lot of bandwidth with padding, (2) defer sending some messages
        # and create a packet with half that size ("ssize", s for small)
        total_ssize = max(self.MIN_PACKET_SIZE, total_lsize // 2)
        # position of the last message boundary within the first total_ssize bytes (-1 if none)
        payload_ssize = buf.rfind(b"\n", 0, total_ssize)
        if payload_ssize != -1:
            payload_ssize += 1  # for "\n" char
            npad_ssize = total_ssize - payload_ssize
        else:
            # no complete message fits: make option (2) lose the comparison below
            npad_ssize = float("inf")
        # decide between (1) and (2):
        if self._force_send or npad_lsize <= npad_ssize:
            # (1) create "lsize" packet: consume full buffer
            npad = npad_lsize
            p_idx = payload_lsize
        else:
            # (2) create "ssize" packet: consume some, but defer some for later
            npad = npad_ssize
            p_idx = payload_ssize
        # pad by adding spaces near end
        # self.session.maybe_log(
        #     f"PaddedRSTransport. calling low-level write(). "
        #     f"chose between (lsize:{payload_lsize}+{npad_lsize}, ssize:{payload_ssize}+{npad_ssize}). "
        #     f"won: {'tie' if npad_lsize == npad_ssize else 'lsize' if npad_lsize < npad_ssize else 'ssize'}."
        # )
        json_rpc_terminator = buf[p_idx-2:p_idx]
        assert json_rpc_terminator in (b"}\n", b"]\n"), f"unexpected {json_rpc_terminator=!r}"
        # the spaces are inserted just before the closing bracket/brace of the last message
        buf2 = buf[:p_idx-2] + (npad * b" ") + json_rpc_terminator
        self._asyncio_transport.write(buf2)
        self._last_send = time.monotonic()
        del self._sbuffer[:p_idx]
        if not self._sbuffer:
            self._sbuffer_has_data_evt.clear()

    async def _poll_sbuffer(self):
        """Loop until the transport is closing, flushing buffered data that has waited long enough."""
        while not self.is_closing():
            await self._can_send.wait()
            await self._sbuffer_has_data_evt.wait()  # to avoid busy-waiting
            self._maybe_consume_sbuffer()
            # If there is still data in the buffer, sleep until it would time out.
            # note: If the transport is ~idle, when we wake up, we will send the current buf data,
            #       but if busy, we might wake up to completely new buffer contents. Either is fine.
            if len(self._sbuffer) > 0:
                timeout_abs = self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS
                timeout_rel = max(0.0, timeout_abs - time.monotonic())
                await asyncio.sleep(timeout_rel)

    def connection_made(self, transport: asyncio.BaseTransport):
        """Start the buffer-polling task for NotificationSessions; otherwise always force-send."""
        super().connection_made(transport)
        if isinstance(self.session, NotificationSession):
            coro = self.session.taskgroup.spawn(self._poll_sbuffer())
            self._sbuffer_task = self.loop.create_task(coro)
        else:
            # This a short-lived "fetch_certificate"-type session.
            # No polling here, we always force-empty the buffer.
            self._force_send = True
392

393

394
class ServerAddr:
    """Address of a server: host, port and network protocol.

    Instances compare equal (and hash equal) when host, port and protocol
    all match. ``str(addr)`` has the form ``<net_addr>:<protocol>``, which is
    the format ``from_str`` parses.
    """

    def __init__(self, host: str, port: Union[int, str], *, protocol: Optional[str] = None):
        """Validate and store the address.

        host: hostname or IP address; an IPv6 address may be wrapped in "[...]".
        port: port number, as int or str.
        protocol: one of _KNOWN_NETWORK_PROTOCOLS; defaults to 's' when None.

        Raises ValueError if host is empty, if host/port are rejected by
        NetAddress, or if protocol is unknown.
        """
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Constructs a ServerAddr or raises ValueError."""
        # host might be IPv6 address, hence do rsplit:
        # (too few ':'-separated parts makes the unpacking itself raise ValueError)
        host, port, protocol = str(s).rsplit(':', 2)
        # use cls so that subclasses get instances of their own type
        return cls(host=host, port=port, protocol=protocol)

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.
        Does not raise - just returns None if guessing failed.
        Ongoing compatibility not guaranteed.
        """
        if not s:
            return None
        host = ""
        if s[0] == "[" and "]" in s:  # IPv6 address
            host_end = s.index("]")
            host = s[1:host_end]
            s = s[host_end+1:]
        items = str(s).rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        host = host or items[0]
        port = items[1]
        if len(items) >= 3:
            protocol = items[2]
        else:
            protocol = PREFERRED_NETWORK_PROTOCOL
        try:
            # use cls so that subclasses get instances of their own type
            return cls(host=host, port=port, protocol=protocol)
        except ValueError:
            return None

    def to_friendly_name(self) -> str:
        """Return the address as text, omitting the protocol when it is the preferred one."""
        # note: this method is closely linked to from_str_with_inference:
        # the protocol hidden here must be the one that method fills in by default.
        if self.protocol == PREFERRED_NETWORK_PROTOCOL:  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        """Return the JSON-serialisable form, which is ``str(self)``."""
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        """Return the "host:port" string as formatted by NetAddress."""
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            # let Python try the reflected comparison; "==" still ends up False
            return NotImplemented
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash((self.host, self.port, self.protocol))
479

480

481
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
5✔
482
    filename = host
5✔
483
    try:
5✔
484
        ip = ip_address(host)
5✔
485
    except ValueError:
5✔
486
        pass
5✔
487
    else:
488
        if isinstance(ip, IPv6Address):
×
489
            filename = f"ipv6_{ip.packed.hex()}"
×
490
    return os.path.join(config.path, 'certs', filename)
5✔
491

492

493
class Interface(Logger):
5✔
494

495
    def __init__(self, *, network: 'Network', server: ServerAddr):
        """Set up state for a connection to *server* and schedule ``run()``.

        The ``run()`` coroutine is spawned on the network's task group, via
        the network's asyncio loop (thread-safe scheduling).
        """
        # future resolved by _mark_ready, cancelled on disconnect
        self.ready = network.asyncio_loop.create_future()
        self.got_disconnected = asyncio.Event()
        self.server = server
        # NOTE(review): presumably must come after self.server is set, as
        # diagnostic_name() reads it -- confirm against Logger.
        Logger.__init__(self)
        assert network.config.path
        self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
        self.blockchain = None  # type: Optional[Blockchain]
        self._requested_chunks = set()  # type: Set[int]
        self.network = network
        self.session = None  # type: Optional[NotificationSession]
        self._ipaddr_bucket = None
        # Set up proxy.
        # - for servers running on localhost, the proxy is not used. If user runs their own server
        #   on same machine, this lets them enable the proxy (which is used for e.g. FX rates).
        #   note: we could maybe relax this further and bypass the proxy for all private
        #         addresses...? e.g. 192.168.x.x
        if util.is_localhost(server.host):
            self.logger.info(f"looks like localhost: not using proxy for this server")
            self.proxy = None
        else:
            self.proxy = ESocksProxy.from_network_settings(network)

        # Latest block header and corresponding height, as claimed by the server.
        # Note that these values are updated before they are verified.
        # Especially during initial header sync, verification can take a long time.
        # Failing verification will get the interface closed.
        self.tip_header = None  # type: Optional[dict]
        self.tip = 0

        self.fee_estimates_eta = {}  # type: Dict[int, int]

        # Dump network messages (only for this interface).  Set at runtime from the console.
        self.debug = False

        self.taskgroup = OldTaskGroup()

        async def spawn_task():
            # spawn run() in the network's task group and name the task after the server
            task = await self.network.taskgroup.spawn(self.run())
            task.set_name(f"interface::{str(server)}")
        asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)
536

537
    @property
    def host(self) -> str:
        """Host of the server this interface is for (``self.server.host``)."""
        return self.server.host
540

541
    @property
    def port(self) -> int:
        """Port of the server this interface is for (``self.server.port``)."""
        return self.server.port
544

545
    @property
    def protocol(self) -> str:
        """Protocol code of the server this interface is for (``self.server.protocol``)."""
        return self.server.protocol
548

549
    def diagnostic_name(self) -> str:
        """Return the server's "host:port" string, used to identify this interface."""
        return self.server.net_addr_str()
551

552
    def __str__(self) -> str:
        """Return "<Interface ...>" wrapping the diagnostic name."""
        return f"<Interface {self.diagnostic_name()}>"
554

555
    async def is_server_ca_signed(self, ca_ssl_context: ssl.SSLContext) -> bool:
        """Probe the server using a CA-enforcing SSL context.

        Returns True when the session can be opened, and False when the only
        problem is that the server presents a self-signed certificate.
        Every other failure is re-raised.
        """
        try:
            await self.open_session(ssl_context=ca_ssl_context, exit_early=True)
        except ConnectError as e:
            cause = e.__cause__
            is_self_signed = (
                isinstance(cause, ssl.SSLCertVerificationError)
                and cause.reason == 'CERTIFICATE_VERIFY_FAILED'
                and cause.verify_code == 18  # "self signed certificate"
            )
            if not is_self_signed:
                # Not good. Cannot use this server.
                raise
            # Good. We will use this server as self-signed.
            return False
        else:
            # Good. We will use this server as CA-signed.
            return True
573

574
    async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context: ssl.SSLContext) -> None:
        """Create the cert file for this server: empty if CA-signed, else the pinned cert.

        Raises InvalidOptionCombination if the server is CA-signed while an
        expected fingerprint is configured.
        """
        if not await self.is_server_ca_signed(ca_ssl_context):
            await self._save_certificate()
            return
        if self._get_expected_fingerprint():
            raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
        # empty file means this is CA signed, not self-signed
        with open(self.cert_path, 'w') as cert_file:
            cert_file.write('')
584

585
    def _is_saved_ssl_cert_available(self) -> bool:
        """Return whether a usable cert file for this server exists at ``self.cert_path``.

        Returns False if the file does not exist, or if the pinned cert has
        expired (in which case the file is deleted). An empty file marks a
        CA-signed server and counts as available.

        Raises InvalidOptionCombination if the file marks a CA-signed server
        while an expected fingerprint is configured, and ErrorParsingSSLCert
        if the saved cert cannot be parsed.
        """
        if not os.path.exists(self.cert_path):
            return False
        with open(self.cert_path, 'r') as f:
            contents = f.read()
        if contents == '':  # CA signed
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            return True
        # pinned self-signed cert
        try:
            b = pem.dePem(contents, 'CERTIFICATE')
        except SyntaxError as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x = x509.X509(b)
        except Exception as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x.check_date()
        except x509.CertificateError as e:
            self.logger.info(f"certificate has expired: {e}")
            os.unlink(self.cert_path)  # delete pinned cert only in this case
            return False
        # NOTE(review): presumably raises on a fingerprint mismatch; the helper is outside this chunk.
        self._verify_certificate_fingerprint(bytes(b))
        return True
613

614
    async def _get_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Return the SSL context to use for this server, or None for plaintext TCP.

        If no cert file is saved for the server yet, one is created first.
        An empty cert file selects the CA-enforcing context; a non-empty one
        is used as the pinned certificate.

        Raises ErrorGettingSSLCertFromServer if creating the cert file fails
        with OSError, ConnectError or a SOCKS error.
        """
        if self.protocol != 's':
            # using plaintext TCP
            return None

        # see if we already have cert for this server; or get it for the first time
        ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
        have_saved_cert = self._is_saved_ssl_cert_available()
        if not have_saved_cert:
            try:
                await self._try_saving_ssl_cert_for_first_time(ca_sslc)
            except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
                raise ErrorGettingSSLCertFromServer(e) from e
        # now we have a file saved in our certificate store
        if os.stat(self.cert_path).st_size == 0:
            # CA signed cert
            return ca_sslc
        # pinned self-signed cert
        pinned_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
        # note: Flag "ssl.VERIFY_X509_STRICT" is enabled by default in python 3.13+ (disabled in older versions).
        #       We explicitly disable it as it breaks lots of servers.
        pinned_sslc.verify_flags &= ~ssl.VERIFY_X509_STRICT
        pinned_sslc.check_hostname = False
        return pinned_sslc
639

640
    def handle_disconnect(func):
        """Decorator for Interface coroutines: log disconnect errors and always clean up.

        GracefulDisconnect and aiorpcx RPCError raised by *func* are logged and
        swallowed; other exceptions propagate. In every case the ``finally``
        block marks the interface as disconnected, cancels its remaining
        tasks, notifies the network, and cancels the ``ready`` future.
        """
        @functools.wraps(func)
        async def wrapper_func(self: 'Interface', *args, **kwargs):
            try:
                return await func(self, *args, **kwargs)
            except GracefulDisconnect as e:
                # the exception carries the level it wants to be logged at
                self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
            except aiorpcx.jsonrpc.RPCError as e:
                self.logger.warning(f"disconnecting due to {repr(e)}")
                self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
            finally:
                self.got_disconnected.set()
                # Make sure taskgroup gets cleaned-up. This explicit clean-up is needed here
                # in case the "with taskgroup" ctx mgr never got a chance to run:
                await self.taskgroup.cancel_remaining()
                await self.network.connection_down(self)
                # if was not 'ready' yet, schedule waiting coroutines:
                self.ready.cancel()
        return wrapper_func
659

660
    @ignore_exceptions  # do not kill network.taskgroup
    @log_exceptions
    @handle_disconnect
    async def run(self):
        """Main coroutine of the interface: obtain the SSL context, then open the session.

        Certificate and connection failures are logged and end the coroutine
        without raising.
        """
        try:
            ssl_context = await self._get_ssl_context()
        except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
            self.logger.info(f'disconnecting due to: {repr(e)}')
            return
        try:
            await self.open_session(ssl_context=ssl_context)
        except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
            # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
            ssl_error_on_fixed_main_server = (
                isinstance(e, ConnectError)
                and isinstance(e.__cause__, ssl.SSLError)
                and self.is_main_server()
                and not self.network.auto_connect
            )
            if not ssl_error_on_fixed_main_server:
                self.logger.info(f'disconnecting due to: {repr(e)}')
                return
            self.logger.warning(f'Cannot connect to main server due to SSL error '
                                f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
680

681
    def _mark_ready(self) -> None:
        """Pick the blockchain for this interface and resolve the ``ready`` future.

        Does nothing if ``ready`` is already resolved. Raises GracefulDisconnect
        if ``ready`` was cancelled in the meantime.
        """
        ready = self.ready
        if ready.cancelled():
            raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
        if ready.done():
            return

        assert self.tip_header
        # use the chain that contains the tip header, falling back to the best known chain
        self.blockchain = blockchain.check_header(self.tip_header) or blockchain.get_best_chain()
        assert self.blockchain is not None

        self.logger.info(f"set blockchain with height {self.blockchain.height()}")

        ready.set_result(1)
698

699
    def is_connected_and_ready(self) -> bool:
        """Return True if ``ready`` is done and ``got_disconnected`` has not been set."""
        if not self.ready.done():
            return False
        return not self.got_disconnected.is_set()
701

702
    async def _save_certificate(self) -> None:
5✔
703
        if not os.path.exists(self.cert_path):
×
704
            # we may need to retry this a few times, in case the handshake hasn't completed
705
            for _ in range(10):
×
706
                dercert = await self._fetch_certificate()
×
707
                if dercert:
×
708
                    self.logger.info("succeeded in getting cert")
×
709
                    self._verify_certificate_fingerprint(dercert)
×
710
                    with open(self.cert_path, 'w') as f:
×
711
                        cert = ssl.DER_cert_to_PEM_cert(dercert)
×
712
                        # workaround android bug
713
                        cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
×
714
                        f.write(cert)
×
715
                        # even though close flushes, we can't fsync when closed.
716
                        # and we must flush before fsyncing, cause flush flushes to OS buffer
717
                        # fsync writes to OS buffer to disk
718
                        f.flush()
×
719
                        os.fsync(f.fileno())
×
720
                    break
×
721
                await asyncio.sleep(1)
×
722
            else:
723
                raise GracefulDisconnect("could not get certificate after 10 tries")
×
724

725
    async def _fetch_certificate(self) -> bytes:
        """Connect to the server without certificate verification and return its cert in DER form.

        Note: the caller (_save_certificate) treats a falsy result as
        "not available yet" and retries.
        """
        # verification is disabled on purpose: we only want to read the peer's certificate
        unverified_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
        unverified_ctx.check_hostname = False
        unverified_ctx.verify_mode = ssl.CERT_NONE
        client = _RSClient(
            session_factory=RPCSession,
            host=self.host, port=self.port,
            ssl=unverified_ctx,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        )
        async with client as session:
            asyncio_transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
            ssl_object = asyncio_transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
            return ssl_object.getpeercert(binary_form=True)
739

740
    def _get_expected_fingerprint(self) -> Optional[str]:
5✔
741
        if self.is_main_server():
×
742
            return self.network.config.NETWORK_SERVERFINGERPRINT
×
NEW
743
        return None
×
744

745
    def _verify_certificate_fingerprint(self, certificate: bytes) -> None:
5✔
746
        expected_fingerprint = self._get_expected_fingerprint()
×
747
        if not expected_fingerprint:
×
748
            return
×
749
        fingerprint = hashlib.sha256(certificate).hexdigest()
×
750
        fingerprints_match = fingerprint.lower() == expected_fingerprint.lower()
×
751
        if not fingerprints_match:
×
752
            util.trigger_callback('cert_mismatch')
×
753
            raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
×
754
        self.logger.info("cert fingerprint verification passed")
×
755

756
    async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
        """Request the header at `height` from the server and return it deserialized.

        `mode` is only used for logging here.
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        self.logger.info(f'requesting block header {height} in {mode=}')
        # use lower timeout as we usually have network.bhi_lock here
        urgent_timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
        header_hex = await self.session.send_request(
            'blockchain.block.header', [height], timeout=urgent_timeout)
        raw_header = bytes.fromhex(header_hex)
        return blockchain.deserialize_header(raw_header, height)
764

765
    async def request_chunk(
        self,
        height: int,
        *,
        tip: Optional[int] = None,
        can_return_early: bool = False,
    ) -> Optional[Tuple[bool, int]]:
        """Request the 2016-header chunk containing `height` and connect it to self.blockchain.

        If `tip` is given, the number of requested headers is capped so that the
        request does not go past that height.

        Returns None if `can_return_early` is set and this chunk is already being
        requested. Otherwise returns (conn, num_headers), where conn is the result
        of connecting the chunk and num_headers is 0 if connecting failed.
        Raises RequestCorrupted if the response is inconsistent.
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        index = height // 2016
        if can_return_early and index in self._requested_chunks:
            return None
        self.logger.info(f"requesting chunk from height {height}")
        first_height = index * 2016
        if tip is None:
            size = 2016
        else:
            size = max(0, min(2016, tip - first_height + 1))
        try:
            self._requested_chunks.add(index)
            res = await self.session.send_request('blockchain.block.headers', [first_height, size])
        finally:
            self._requested_chunks.discard(index)
        # validate the shape of the response before using it
        for field in ('count', 'hex', 'max'):
            assert_dict_contains_field(res, field_name=field)
        assert_non_negative_integer(res['count'])
        assert_non_negative_integer(res['max'])
        assert_hex_str(res['hex'])
        expected_hex_len = HEADER_SIZE * 2 * res['count']
        if len(res['hex']) != expected_hex_len:
            raise RequestCorrupted('inconsistent chunk hex and count')
        # we never request more than 2016 headers, but we enforce those fit in a single response
        if res['max'] < 2016:
            raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016")
        if res['count'] != size:
            raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
        conn = self.blockchain.connect_chunk(index, res['hex'])
        return conn, (res['count'] if conn else 0)
804

805
    def is_main_server(self) -> bool:
        """Return whether this interface is the network's main one.

        That is the case if it is the network's current interface, or if the
        network has no interface and its default server equals our server.
        """
        if self.network.interface == self:
            return True
        return self.network.interface is None and self.network.default_server == self.server
808

809
    async def open_session(
        self,
        *,
        ssl_context: Optional[ssl.SSLContext],
        exit_early: bool = False,
    ):
        """Connect to the server, negotiate the protocol version and run the interface's tasks.

        With `exit_early`, returns right after the 'server.version' exchange.
        Raises GracefulDisconnect on version-negotiation problems, when the server
        does not fit the healthy spread of connected servers, and for a set of
        RPC error codes raised from the task group.
        """
        def session_factory(*args, iface=self, **kwargs):
            return NotificationSession(*args, **kwargs, interface=iface)

        client = _RSClient(
            session_factory=session_factory,
            host=self.host, port=self.port,
            ssl=ssl_context,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        )
        async with client as session:
            self.session = session  # type: NotificationSession
            generic_timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Generic)
            self.session.set_default_timeout(generic_timeout)
            try:
                ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
            except aiorpcx.jsonrpc.RPCError as e:
                raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
            if exit_early:
                return
            if ver[1] != version.PROTOCOL_VERSION:
                raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                         f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
            if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
                raise GracefulDisconnect(f'too many connected servers already '
                                         f'in bucket {self.bucket_based_on_ipaddress()}')
            self.logger.info(f"connection established. version: {ver}")

            graceful_rpc_error_codes = (
                JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                JSONRPC.SERVER_BUSY,
                JSONRPC.METHOD_NOT_FOUND,
                JSONRPC.INTERNAL_ERROR,
            )
            try:
                async with self.taskgroup as group:
                    for job in (
                        self.ping,
                        self.request_fee_estimates,
                        self.run_fetch_blocks,
                        self.monitor_connection,
                    ):
                        await group.spawn(job)
            except aiorpcx.jsonrpc.RPCError as e:
                if e.code not in graceful_rpc_error_codes:
                    raise
                log_level = logging.WARNING if self.is_main_server() else logging.INFO
                raise GracefulDisconnect(e, log_level=log_level) from e
            finally:
                self.got_disconnected.set()  # set this ASAP, ideally before any awaits
857

858
    async def monitor_connection(self):
        """Poll the session once per second; raise GracefulDisconnect once it is gone or closing."""
        while True:
            await asyncio.sleep(1)
            # If the session/transport is no longer open, we disconnect.
            # e.g. if the remote cleanly sends EOF, we would handle that here.
            # note: If the user pulls the ethernet cable or disconnects wifi,
            #       ideally we would detect that here, so that the GUI/etc can reflect that.
            #       - On Android, this seems to work reliably , where asyncio.BaseProtocol.connection_lost()
            #         gets called with e.g. ConnectionAbortedError(103, 'Software caused connection abort').
            #       - On desktop Linux/Win, it seems BaseProtocol.connection_lost() is not called in such cases.
            #         Hence, in practice the connection issue will only be detected the next time we try
            #         to send a message (plus timeout), which can take minutes...
            if self.session and not self.session.is_closing():
                continue
            raise GracefulDisconnect('session was closed')
872

873
    async def ping(self):
        """Send 'server.ping' forever, at random intervals of up to 300 seconds."""
        # We periodically send a "ping" msg to make sure the server knows we are still here.
        # Adding a bit of randomness generates some noise against traffic analysis.
        while True:
            delay = random.random() * 300
            await asyncio.sleep(delay)
            await self.session.send_request('server.ping')
            await self._maybe_send_noise()
880

881
    async def _maybe_send_noise(self):
        """Send extra 'server.ping' requests; each round happens with probability 0.2."""
        while True:
            if not random.random() < 0.2:
                break
            # wait a random sub-second delay before each noise ping
            await asyncio.sleep(random.random())
            await self.session.send_request('server.ping')
885

886
    async def request_fee_estimates(self):
        """Every 60 seconds, query fee estimates for all but the last FEE_ETA_TARGETS entry.

        Non-negative results are stored in ``self.fee_estimates_eta`` and the
        network is told to update its fee estimates.
        """
        while True:
            # all requests of one round are spawned in one group, which is awaited on exit
            async with OldTaskGroup() as group:
                fee_tasks = [
                    (target, await group.spawn(self.get_estimatefee(target)))
                    for target in FEE_ETA_TARGETS[:-1]
                ]
            for nblock_target, task in fee_tasks:
                fee = task.result()
                if fee >= 0:
                    assert isinstance(fee, int)
                    self.fee_estimates_eta[nblock_target] = fee
            self.network.update_fee_estimates()
            await asyncio.sleep(60)
899

900
    async def close(self, *, force_after: int = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if not self.session:
            # no session was ever opened, nothing to close
            return
        await self.session.close(force_after=force_after)
        # monitor_connection will cancel tasks
907

908
    async def run_fetch_blocks(self):
        """Subscribe to header notifications and process each new server tip as it arrives.

        Raises GracefulDisconnect if the server's tip is below the highest checkpoint.
        """
        header_queue = asyncio.Queue()
        await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
        while True:
            raw_header = (await header_queue.get())[0]
            height = raw_header['height']
            self.tip_header = blockchain.deserialize_header(bfh(raw_header['hex']), height)
            self.tip = height
            if self.tip < constants.net.max_checkpoint():
                raise GracefulDisconnect('server tip below max checkpoint')
            self._mark_ready()
            blockchain_updated = await self._process_header_at_tip()
            # header processing done
            if self.is_main_server():
                self.logger.info(f"new chain tip on main interface. {height=}")
            if blockchain_updated:
                util.trigger_callback('blockchain_updated')
            util.trigger_callback('network_updated')
            await self.network.switch_unwanted_fork_interface()
            await self.network.switch_lagging_interface()
            await self.taskgroup.spawn(self._maybe_send_noise())
931

932
    async def _process_header_at_tip(self) -> bool:
        """Returns:
        False - boring fast-forward: we already have this header as part of this blockchain from another interface,
        True - new header we didn't have, or reorg
        """
        tip_height, tip_header = self.tip, self.tip_header
        async with self.network.bhi_lock:
            already_known = (self.blockchain.height() >= tip_height
                             and self.blockchain.check_header(tip_header))
            if already_known:
                # another interface amended the blockchain
                return False
            _, next_height = await self.step(tip_height, header=tip_header)
            # in the simple case, next_height == self.tip+1
            if next_height <= self.tip:
                await self.sync_until(next_height)
            return True
947

948
    async def sync_until(
        self,
        height: int,
        *,
        next_height: Optional[int] = None,
    ) -> Tuple[ChainResolutionMode, int]:
        """Fetch headers starting at `height` until `next_height` (default: self.tip) is passed.

        Uses chunk requests when more than 10 headers are missing, and single
        steps otherwise. Returns the last resolution mode and the next height to fetch.
        Raises GracefulDisconnect if a chunk at or below the highest checkpoint
        cannot be connected.
        """
        if next_height is None:
            next_height = self.tip
        last = None  # type: Optional[ChainResolutionMode]
        while last is None or height <= next_height:
            state_before = (last, height)
            far_behind = next_height > height + 10  # TODO make smarter. the protocol allows asking for n headers
            if not far_behind:
                last, height = await self.step(height)
            else:
                could_connect, num_headers = await self.request_chunk(height, tip=next_height)
                if not could_connect:
                    if height <= constants.net.max_checkpoint():
                        raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                    last, height = await self.step(height)
                    continue  # note: this path skips the progress assertion below
                util.trigger_callback('blockchain_updated')
                util.trigger_callback('network_updated')
                chunk_start = height // 2016 * 2016
                height = chunk_start + num_headers
                assert height <= next_height+1, (height, self.tip)
                last = ChainResolutionMode.CATCHUP
            assert state_before != (last, height), 'had to prevent infinite loop in interface.sync_until'
        return last, height
975

976
    async def step(
        self,
        height: int,
        *,
        header: Optional[dict] = None,  # at 'height'
    ) -> Tuple[ChainResolutionMode, int]:
        """Process the header at `height` and return (mode, next height to fetch).

        If the header is already known or connects to a known chain, the mode is
        CATCHUP. Otherwise a backward search and, if needed, a binary search are
        run to locate the fork point, which is then resolved.
        Headers carrying a 'mock' key use the callables stored there instead of
        the blockchain module.
        """
        def check(hdr):
            if 'mock' in hdr:
                return hdr['mock']['check'](hdr)
            return blockchain.check_header(hdr)

        def connects(hdr, at_height):
            if 'mock' in hdr:
                return hdr['mock']['connect'](at_height)
            return blockchain.can_connect(hdr)

        assert 0 <= height <= self.tip, (height, self.tip)
        if header is None:
            header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP)

        chain = check(header)
        if chain:
            if isinstance(chain, Blockchain):
                self.blockchain = chain
            # note: there is an edge case here that is not handled.
            # we might know the blockhash (enough for check_header) but
            # not have the header itself. e.g. regtest chain with only genesis.
            # this situation resolves itself on the next block
            return ChainResolutionMode.CATCHUP, height+1

        can_connect = connects(header, height)
        if not can_connect:
            self.logger.info(f"can't connect new block: {height=}")
            height, header, bad, bad_header = await self._search_headers_backwards(height, header=header)
            chain = check(header)
            can_connect = connects(header, height)
            assert chain or can_connect
        if can_connect:
            self.logger.info(f"new block: {height=}")
            height += 1
            if isinstance(can_connect, Blockchain):  # not when mocking
                self.blockchain = can_connect
                self.blockchain.save_header(header)
            return ChainResolutionMode.CATCHUP, height

        # only reachable after the backward search above, so bad/bad_header are set
        good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
        return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
1012

1013
    async def _search_headers_binary(
        self,
        height: int,
        bad: int,
        bad_header: dict,
        chain: Optional[Blockchain],
    ) -> Tuple[int, int, dict]:
        """Bisect between a good height and a bad height until they are adjacent.

        `height` is the initial good height; `bad`/`bad_header` describe a header
        that checks against no chain. Returns (good, bad, bad_header) with
        good + 1 == bad. Raises if the final bad header cannot connect.
        """
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)

        if isinstance(chain, Blockchain):
            self.blockchain = chain
        good = height
        while True:
            assert good < bad, (good, bad)
            height = (good + bad) // 2
            self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
            header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY)
            if 'mock' in header:
                chain = header['mock']['check'](header)
            else:
                chain = blockchain.check_header(header)
            if chain:
                if isinstance(chain, Blockchain):
                    self.blockchain = chain
                good = height
            else:
                bad, bad_header = height, header
            if good + 1 == bad:
                break

        # note: the mock callable receives the last midpoint, as stored in `height`
        mock = 'mock' in bad_header and bad_header['mock']['connect'](height)
        real = not mock and self.blockchain.can_connect(bad_header, check_height=False)
        if not (real or mock):
            raise Exception('unexpected bad header during binary: {}'.format(bad_header))
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.logger.info(f"binary search exited. good {good}, bad {bad}")
        return good, bad, bad_header
1048

1049
    async def _resolve_potential_chain_fork_given_forkpoint(
        self,
        good: int,
        bad: int,
        bad_header: dict,
    ) -> Tuple[ChainResolutionMode, int]:
        """Decide between extending self.blockchain and forking it at `bad`.

        Returns (NO_FORK, good + 1) if self.blockchain ends exactly at `good`;
        otherwise forks at `bad_header` and returns (FORK, bad + 1).
        """
        assert good + 1 == bad
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)
        # 'good' is the height of a block 'good_header', somewhere in self.blockchain.
        # bad_header connects to good_header; bad_header itself is NOT in self.blockchain.

        bh = self.blockchain.height()
        assert bh >= good, (bh, good)
        if bh == good:
            height = good + 1
            self.logger.info(f"catching up from {height}")
            return ChainResolutionMode.NO_FORK, height

        # this is a new fork we don't yet have
        self.logger.info(f"new fork at bad height {bad}")
        if 'mock' in bad_header:
            forkfun = bad_header['mock']['fork']
        else:
            forkfun = self.blockchain.fork
        forked_chain = forkfun(bad_header)  # type: Blockchain
        self.blockchain = forked_chain
        assert forked_chain.forkpoint == bad
        return ChainResolutionMode.FORK, bad + 1
1076

1077
    async def _search_headers_backwards(
        self,
        height: int,
        *,
        header: dict,
    ) -> Tuple[int, dict, int, dict]:
        """Walk backwards from `height` until a header is found that is known or connects.

        The distance to self.tip is doubled on every round. Returns
        (height, header, bad, bad_header): the found header with its height, and
        the last header (with its height) that did not check against any chain.
        Raises GracefulDisconnect if even the header at the highest checkpoint
        neither checks nor connects.
        """
        async def header_is_still_unknown():
            """Fetch the header at the current height; True means keep searching."""
            nonlocal height, header
            max_cp = constants.net.max_checkpoint()
            at_checkpoint = height <= max_cp
            if at_checkpoint:
                height = max_cp  # never search below the highest checkpoint
            header = await self.get_block_header(height, mode=ChainResolutionMode.BACKWARD)
            if 'mock' in header:
                chain = header['mock']['check'](header)
                can_connect = header['mock']['connect'](height)
            else:
                chain = blockchain.check_header(header)
                can_connect = blockchain.can_connect(header)
            if chain or can_connect:
                return False
            if at_checkpoint:
                raise GracefulDisconnect("server chain conflicts with checkpoints")
            return True

        bad, bad_header = height, header
        _assert_header_does_not_check_against_any_chain(bad_header)
        with blockchain.blockchains_lock:
            chains = list(blockchain.blockchains.values())
        if 'mock' in header:
            local_max = float('inf')
        else:
            local_max = max([0] + [chain.height() for chain in chains])
        # start just above our highest local header, but always below the bad one
        height = min(local_max + 1, height - 1)
        while await header_is_still_unknown():
            bad, bad_header = height, header
            delta = self.tip - height
            height = self.tip - 2 * delta

        _assert_header_does_not_check_against_any_chain(bad_header)
        self.logger.info(f"exiting backward mode at {height}")
        return height, header, bad, bad_header
1111

1112
    @classmethod
    def client_name(cls) -> str:
        """Return the client identifier string: 'electrum/' followed by the Electrum version."""
        return 'electrum/{}'.format(version.ELECTRUM_VERSION)
1115

1116
    def is_tor(self):
        """Return True if the server's host name ends with '.onion'."""
        hostname = self.host
        return hostname.endswith('.onion')
1118

1119
    def ip_addr(self) -> Optional[str]:
        """Return the host of the session's remote address as a string.

        Returns None if there is no session or it has no remote address.
        """
        session = self.session
        peer_addr = session.remote_address() if session else None
        if not peer_addr:
            return None
        return str(peer_addr.host)
1125

1126
    def bucket_based_on_ipaddress(self) -> str:
        """Return the bucket this server falls into, based on its IP address.

        Onion servers share one bucket. IPv4 addresses are bucketed by their /16
        network and IPv6 addresses by their /48 network. The empty string is
        returned for loopback addresses and when no usable address is known.
        A non-empty result is cached in ``self._ipaddr_bucket``; an empty one is
        recomputed on the next call.
        """
        def do_bucket() -> str:
            if self.is_tor():
                return BUCKET_NAME_OF_ONION_SERVERS
            raw_addr = self.ip_addr()
            if raw_addr is None:
                # no peer address known; handled explicitly
                # instead of relying on ip_address(None) raising ValueError
                return ''
            try:
                ip_addr = ip_address(raw_addr)  # type: Union[IPv4Address, IPv6Address]
            except ValueError:
                return ''
            # note: address objects are always truthy, so no emptiness check is needed here
            if ip_addr.is_loopback:  # localhost is exempt
                return ''
            if ip_addr.version == 4:
                slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
                return str(slash16)
            elif ip_addr.version == 6:
                slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
                return str(slash48)
            return ''

        if not self._ipaddr_bucket:
            self._ipaddr_bucket = do_bucket()
        return self._ipaddr_bucket
1149

1150
    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
        """Request the merkle branch of a transaction and validate the response's shape.

        Returns the server's response dict (with 'block_height', 'merkle' and 'pos').
        """
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        # do request
        request_args = [tx_hash, tx_height]
        res = await self.session.send_request('blockchain.transaction.get_merkle', request_args)
        # check response
        block_height = assert_dict_contains_field(res, field_name='block_height')
        merkle_branch = assert_dict_contains_field(res, field_name='merkle')
        tx_pos = assert_dict_contains_field(res, field_name='pos')
        # note: tx_height was just a hint to the server, don't enforce the response to match it
        assert_non_negative_integer(block_height)
        assert_non_negative_integer(tx_pos)
        assert_list_or_tuple(merkle_branch)
        for node_hash in merkle_branch:
            assert_hash256_str(node_hash)
        return res
1168

1169
    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
        """Request a raw transaction by txid and return it as received.

        Raises RequestCorrupted if the response is not hex, cannot be
        deserialized, or does not hash to `tx_hash`.
        """
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
        # validate response
        if not is_hex_str(raw):
            raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw!r}")
        tx = Transaction(raw)
        try:
            tx.deserialize()  # see if raises
        except Exception as e:
            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
        if tx.txid() == tx_hash:
            return raw
        raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
1184

1185
    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
        """Request the history of a scripthash and validate the response.

        Raises RequestCorrupted if confirmed heights are not in increasing
        order, if a confirmed tx follows a mempool tx, or if txids repeat.
        """
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
        # check response
        assert_list_or_tuple(res)
        min_allowed_height = 1
        for tx_item in res:
            height = assert_dict_contains_field(tx_item, field_name='height')
            assert_dict_contains_field(tx_item, field_name='tx_hash')
            assert_integer(height)
            assert_hash256_str(tx_item['tx_hash'])
            if height not in (-1, 0):
                # check monotonicity of heights
                if height < min_allowed_height:
                    raise RequestCorrupted('heights of confirmed txs must be in increasing order')
                min_allowed_height = height
                continue
            # heights -1 and 0 take the mempool branch, which requires a fee field
            assert_dict_contains_field(tx_item, field_name='fee')
            assert_non_negative_integer(tx_item['fee'])
            min_allowed_height = float("inf")  # this ensures confirmed txs can't follow mempool txs
        unique_txids = {tx_item['tx_hash'] for tx_item in res}
        if len(unique_txids) != len(res):
            # Either server is sending garbage... or maybe if server is race-prone
            # a recently mined tx could be included in both last block and mempool?
            # Still, it's simplest to just disregard the response.
            raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
        return res
1214

1215
    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
        """Request the unspent outputs of a scripthash and validate each returned item."""
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
        # check response
        assert_list_or_tuple(res)
        for utxo_item in res:
            # presence of all fields is checked first, then their values
            for field in ('tx_pos', 'value', 'tx_hash', 'height'):
                assert_dict_contains_field(utxo_item, field_name=field)
            for field in ('tx_pos', 'value', 'height'):
                assert_non_negative_integer(utxo_item[field])
            assert_hash256_str(utxo_item['tx_hash'])
        return res
1232

1233
    async def get_balance_for_scripthash(self, sh: str) -> dict:
        """Request the balance of a scripthash and validate the response.

        'confirmed' must be a non-negative integer; 'unconfirmed' may be any integer.
        """
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
        # check response
        for field in ('confirmed', 'unconfirmed'):
            assert_dict_contains_field(res, field_name=field)
        assert_non_negative_integer(res['confirmed'])
        assert_integer(res['unconfirmed'])
        return res
1244

1245
    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
        """Request the txid at position `tx_pos` in the block at `tx_height`.

        With `merkle`, the validated response is a dict with 'tx_hash' and
        'merkle'; otherwise it is a txid string.
        """
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        if not is_non_negative_integer(tx_pos):
            raise Exception(f"{repr(tx_pos)} should be non-negative integer")
        # do request
        request_args = [tx_height, tx_pos, merkle]
        res = await self.session.send_request('blockchain.transaction.id_from_pos', request_args)
        # check response
        if not merkle:
            assert_hash256_str(res)
            return res
        assert_dict_contains_field(res, field_name='tx_hash')
        assert_dict_contains_field(res, field_name='merkle')
        assert_hash256_str(res['tx_hash'])
        assert_list_or_tuple(res['merkle'])
        for node_hash in res['merkle']:
            assert_hash256_str(node_hash)
        return res
1266

1267
    async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
        """Request the mempool fee histogram and validate the response.

        Each item must be a pair whose first element is a non-negative number
        and whose second is a non-negative integer, with the first elements in
        strictly decreasing order.

        Raises RequestCorrupted if an item is not a pair or the order is violated.
        """
        # do request
        res = await self.session.send_request('mempool.get_fee_histogram')
        # check response
        assert_list_or_tuple(res)
        prev_fee = float('inf')
        for item in res:
            # the response comes from an untrusted server: check the item shape before
            # unpacking, so a malformed item does not escape as a bare ValueError/TypeError
            assert_list_or_tuple(item)
            if len(item) != 2:
                raise RequestCorrupted(f'fee histogram item should have 2 elements: {item!r}')
            fee, s = item
            assert_non_negative_int_or_float(fee)
            assert_non_negative_integer(s)
            if fee >= prev_fee:  # check monotonicity
                raise RequestCorrupted(f'fees must be in decreasing order')
            prev_fee = fee
        return res
1280

1281
    async def get_server_banner(self) -> str:
        """Request the server banner; raises RequestCorrupted if it is not a str."""
        # do request
        res = await self.session.send_request('server.banner')
        # check response
        if isinstance(res, str):
            return res
        raise RequestCorrupted(f'{res!r} should be a str')
1288

1289
    async def get_donation_address(self) -> str:
        """Request the server's donation address.

        Returns '' if the server sent nothing or an address we do not recognize.
        """
        # do request
        res = await self.session.send_request('server.donation_address')
        # check response
        if not res:  # ignore empty string
            return ''
        if bitcoin.is_address(res):
            return res
        # note: do not hard-fail -- allow server to use future-type
        #       bitcoin address we do not recognize
        self.logger.info(f"invalid donation address from server: {repr(res)}")
        return ''
1301

1302
    async def get_relay_fee(self) -> int:
        """Returns the min relay feerate in sat/kbyte."""
        # do request
        res = await self.session.send_request('blockchain.relayfee')
        # check response
        assert_non_negative_int_or_float(res)
        # the response is scaled by bitcoin.COIN and never reported below zero
        return max(0, int(res * bitcoin.COIN))
1311

1312
    async def get_estimatefee(self, num_blocks: int) -> int:
        """Returns a feerate estimate for getting confirmed within
        num_blocks blocks, in sat/kbyte.
        Returns -1 if the server could not provide an estimate.
        """
        if not is_non_negative_integer(num_blocks):
            raise Exception(f"{repr(num_blocks)} is not a num_blocks")
        # do request
        try:
            res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
        except aiorpcx.jsonrpc.ProtocolError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate, however apparently "electrs" does not conform
            # and sends an error instead. Convert it here:
            if "cannot estimate fee" not in e.message:
                raise
            res = -1
        except aiorpcx.jsonrpc.RPCError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate. "Fulcrum" often sends:
            #   aiorpcx.jsonrpc.RPCError: (-32603, 'internal error: bitcoind request timed out')
            if e.code != JSONRPC.INTERNAL_ERROR:
                raise
            res = -1
        # check response
        if res == -1:
            return res
        assert_non_negative_int_or_float(res)
        return int(res * bitcoin.COIN)
1343

1344

1345
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
5✔
1346
    chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
5✔
1347
    if chain_bad:
5✔
1348
        raise Exception('bad_header must not check!')
×
1349

1350

1351
def check_cert(host, cert):
    """Parse a PEM certificate and print whether its date check fails.

    If the certificate cannot be parsed, the traceback is printed to stdout
    and nothing else is reported.
    """
    try:
        der_bytes = pem.dePem(cert, 'CERTIFICATE')
        parsed = x509.X509(der_bytes)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        return

    try:
        parsed.check_date()
    except Exception:
        expired = True
    else:
        expired = False

    report = "".join([
        "host: %s\n"%host,
        "has_expired: %s\n"% expired,
    ])
    util.print_msg(report)
1368

1369

1370
# Used by tests
1371
def _match_hostname(name, val):
5✔
1372
    if val == name:
×
1373
        return True
×
1374

1375
    return val.startswith('*.') and name.endswith(val[1:])
×
1376

1377

1378
def test_certificates():
    """Run check_cert on every file in the "certs" directory under the config path."""
    from .simple_config import SimpleConfig
    certs_dir = os.path.join(SimpleConfig().path, "certs")
    for filename in os.listdir(certs_dir):
        cert_file = os.path.join(certs_dir, filename)
        with open(cert_file, encoding='utf-8') as f:
            cert = f.read()
        check_cert(filename, cert)
1388

1389
# When executed as a script, check the certificates stored in the config's "certs" directory.
if __name__ == "__main__":
    test_certificates()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc