• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

spesmilo / electrum / 6486955220795392

29 May 2025 02:19PM UTC coverage: 59.437% (-0.09%) from 59.523%
6486955220795392

Pull #9875

CirrusCI

SomberNight
interface: address feedback for PaddedRSTransport
Pull Request #9875: interface: add padding and some noise to protocol messages

13 of 77 new or added lines in 1 file covered. (16.88%)

2 existing lines in 1 file now uncovered.

21688 of 36489 relevant lines covered (59.44%)

2.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

35.57
/electrum/interface.py
1
#!/usr/bin/env python
2
#
3
# Electrum - lightweight Bitcoin client
4
# Copyright (C) 2011 thomasv@gitorious
5
#
6
# Permission is hereby granted, free of charge, to any person
7
# obtaining a copy of this software and associated documentation files
8
# (the "Software"), to deal in the Software without restriction,
9
# including without limitation the rights to use, copy, modify, merge,
10
# publish, distribute, sublicense, and/or sell copies of the Software,
11
# and to permit persons to whom the Software is furnished to do so,
12
# subject to the following conditions:
13
#
14
# The above copyright notice and this permission notice shall be
15
# included in all copies or substantial portions of the Software.
16
#
17
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25
import os
5✔
26
import re
5✔
27
import ssl
5✔
28
import sys
5✔
29
import time
5✔
30
import traceback
5✔
31
import asyncio
5✔
32
import socket
5✔
33
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
5✔
34
from collections import defaultdict
5✔
35
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
5✔
36
import itertools
5✔
37
import logging
5✔
38
import hashlib
5✔
39
import functools
5✔
40
import random
5✔
41

42
import aiorpcx
5✔
43
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
5✔
44
from aiorpcx.curio import timeout_after, TaskTimeout
5✔
45
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
5✔
46
from aiorpcx.rawsocket import RSClient, RSTransport
5✔
47
import certifi
5✔
48

49
from .util import (ignore_exceptions, log_exceptions, bfh, ESocksProxy,
5✔
50
                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
51
                   is_int_or_float, is_non_negative_int_or_float, OldTaskGroup)
52
from . import util
5✔
53
from . import x509
5✔
54
from . import pem
5✔
55
from . import version
5✔
56
from . import blockchain
5✔
57
from .blockchain import Blockchain, HEADER_SIZE
5✔
58
from . import bitcoin
5✔
59
from . import constants
5✔
60
from .i18n import _
5✔
61
from .logging import Logger
5✔
62
from .transaction import Transaction
5✔
63
from .fee_policy import FEE_ETA_TARGETS
5✔
64

65
if TYPE_CHECKING:
5✔
66
    from .network import Network
×
67
    from .simple_config import SimpleConfig
×
68

69

70
# Path of the CA bundle shipped with certifi; used to validate CA-signed server certs.
ca_path = certifi.where()

# Bucket label used for Tor .onion servers (grouping of servers by address).
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# Known transport protocols: 't' = plaintext TCP, 's' = TCP wrapped in TLS.
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
77

78

79
class NetworkTimeout:
    """Timeout presets for network operations. All values are in seconds."""

    # Default profile.
    class Generic:
        NORMAL = 30
        RELAXED = 45
        MOST_RELAXED = 600

    # Tighter limits, for when we would rather fail fast.
    class Urgent(Generic):
        NORMAL = 10
        RELAXED = 20
        MOST_RELAXED = 60
90

91

92
def assert_non_negative_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative integer."""
    if is_non_negative_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative integer')
95

96

97
def assert_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an integer."""
    if is_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be an integer')
100

101

102
def assert_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an int or a float."""
    if is_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be int or float')
105

106

107
def assert_non_negative_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative int or float."""
    if is_non_negative_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
110

111

112
def assert_hash256_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* looks like a hash256 hex string."""
    if is_hash256_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hash256 str')
115

116

117
def assert_hex_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hex string."""
    if is_hex_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hex str')
120

121

122
def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Return ``d[field_name]``.

    Raises RequestCorrupted if *d* is not a dict or the field is missing.
    """
    if not isinstance(d, dict):
        raise RequestCorrupted(f'{d!r} should be a dict')
    if field_name in d:
        return d[field_name]
    raise RequestCorrupted(f'required field {field_name!r} missing from dict')
128

129

130
def assert_list_or_tuple(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a list or a tuple."""
    if isinstance(val, (list, tuple)):
        return
    raise RequestCorrupted(f'{val!r} should be a list or tuple')
133

134

135
class NotificationSession(RPCSession):
    """JSON-RPC session that supports server-push notifications.

    Results of subscription calls are cached by (method, params), so repeated
    subscribe() calls for the same key can be answered locally instead of
    making another network request.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super(NotificationSession, self).__init__(*args, **kwargs)
        self.subscriptions = defaultdict(list)  # rpc-call key -> list of asyncio.Queue
        self.cache = {}  # rpc-call key -> last known result
        self._msg_counter = itertools.count(start=1)  # local ids, used only to correlate log lines
        self.interface = interface
        self.taskgroup = interface.taskgroup
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        """Handle an incoming message from the server.

        Only notifications for keys we subscribed to are expected; anything
        else is treated as a protocol violation and the session is closed.
        """
        self.maybe_log(f"--> {request}")
        try:
            if isinstance(request, Notification):
                # by convention, the last arg carries the result; the rest are params
                params, result = request.args[:-1], request.args[-1]
                key = self.get_hashable_key_for_rpc_call(request.method, params)
                if key in self.subscriptions:
                    self.cache[key] = result
                    for queue in self.subscriptions[key]:
                        await queue.put(request.args)
                else:
                    raise Exception(f'unexpected notification')
            else:
                raise Exception(f'unexpected request. not a notification')
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        """Send a JSON-RPC request and await the response.

        Raises RequestTimedOut (a GracefulDisconnect) on timeout.
        """
        # note: semaphores/timeouts/backpressure etc are handled by
        # aiorpcx. the timeout arg here in most cases should not be set
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # note: RPCSession.send_request raises TaskTimeout in case of a timeout.
            # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
            response = await util.wait_for2(
                super().send_request(*args, **kwargs),
                timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            self.maybe_log(f"--> request timed out: {args} (id: {msg_id})")
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        except BaseException as e:  # cancellations, etc. are useful for debugging
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        else:
            self.maybe_log(f"--> {response} (id: {msg_id})")
            return response

    def set_default_timeout(self, timeout):
        """Set both the per-request timeout and the max send delay (attributes
        defined in the aiorpcx base class)."""
        assert hasattr(self, "sent_request_timeout")  # in base class
        self.sent_request_timeout = timeout
        assert hasattr(self, "max_send_delay")        # in base class
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        """Subscribe *queue* to notifications for (method, params).

        The current value (cached or freshly requested) is immediately put on
        the queue, in the same "params + [result]" shape as notifications.
        """
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key in self.cache:
            result = self.cache[key]
        else:
            result = await self.send_request(method, params)
            self.cache[key] = result
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for v in self.subscriptions.values():
            if queue in v:
                v.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return str(method) + repr(params)

    def maybe_log(self, msg: str) -> None:
        # log only if verbose debugging was enabled for this interface or globally
        if not self.interface: return
        if self.interface.debug or self.interface.network.debug:
            self.interface.logger.debug(msg)

    def default_framer(self):
        # overridden so that max_size can be customized
        max_size = self.interface.network.config.NETWORK_MAX_INCOMING_MSG_SIZE
        assert max_size > 500_000, f"{max_size=} (< 500_000) is too small"
        return NewlineFramer(max_size=max_size)

    async def close(self, *, force_after: Optional[int] = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if force_after is None:
            # We give up after a while and just abort the connection.
            # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
            #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
            # Note: if the ethernet cable was pulled or wifi disconnected, that too might
            #       wait until this timeout is triggered
            force_after = 1  # seconds
        await super().close(force_after=force_after)
242

243

244
class NetworkException(Exception):
    """Base class for network-layer errors in this module."""


class GracefulDisconnect(NetworkException):
    """Raised to close an interface cleanly; logged at ``log_level``."""

    log_level = logging.INFO  # class-level default severity

    def __init__(self, *args, log_level=None, **kwargs):
        Exception.__init__(self, *args, **kwargs)
        if log_level is not None:
            # per-instance override of how loudly this disconnect is reported
            self.log_level = log_level


class RequestTimedOut(GracefulDisconnect):
    def __str__(self):
        return _("Network request timed out.")


class RequestCorrupted(Exception):
    """Server response failed validation."""


class ErrorParsingSSLCert(Exception):
    """Locally saved certificate could not be parsed."""


class ErrorGettingSSLCertFromServer(Exception):
    """Could not fetch the server's certificate."""


class ErrorSSLCertFingerprintMismatch(Exception):
    """Server certificate does not match the expected fingerprint."""


class InvalidOptionCombination(Exception):
    """Mutually exclusive configuration options were combined."""


class ConnectError(NetworkException):
    """Low-level connection establishment failed."""
268

269

270
class _RSClient(RSClient):
    """RSClient that normalizes OS-level connect failures into ConnectError."""

    async def create_connection(self):
        try:
            return await super().create_connection()
        except OSError as exc:
            # "from exc" sets ConnectError.__cause__ to the original OSError,
            # which callers inspect (e.g. for SSL error details)
            raise ConnectError(exc) from exc
277

278

279
class PaddedRSTransport(RSTransport):
    """A raw socket transport that provides basic countermeasures against traffic analysis
    by padding the jsonrpc payload with whitespaces to have ~uniform-size TCP packets.
    (it is assumed that a network observer does not see plaintext transport contents,
    due to it being wrapped e.g. in TLS)
    """

    MIN_PACKET_SIZE = 1024  # outgoing packets are padded to at least this many bytes
    WAIT_FOR_BUFFER_GROWTH_SECONDS = 1.0  # max time a queued message waits to be coalesced

    session: Optional['RPCSession']

    def __init__(self, *args, **kwargs):
        RSTransport.__init__(self, *args, **kwargs)
        self._sbuffer = bytearray()  # "send buffer"
        self._sbuffer_task = None  # type: Optional[asyncio.Task]
        self._sbuffer_has_data_evt = asyncio.Event()
        self._last_send = time.monotonic()
        # when True, _maybe_consume_sbuffer() flushes unconditionally (no coalescing delay)
        self._force_send = False  # type: bool

    # note: this does not call super().write() but is a complete reimplementation
    async def write(self, message):
        """Frame *message* and append it to the send buffer; the actual socket
        write happens in _maybe_consume_sbuffer(), possibly batched/delayed.
        (_can_send and _framer are attributes of the RSTransport base class.)
        """
        await self._can_send.wait()
        if self.is_closing():
            return
        framed_message = self._framer.frame(message)
        self._sbuffer += framed_message
        self._sbuffer_has_data_evt.set()
        self._maybe_consume_sbuffer()

    def _maybe_consume_sbuffer(self) -> None:
        """Maybe take some data from sbuffer and send it on the wire.

        The payload is padded with spaces (inserted just before the json-rpc
        terminator) so the packet size leaks little about the message size.
        """
        if not self._can_send.is_set() or self.is_closing():
            return
        buf = self._sbuffer
        if not buf:
            return
        # if there is enough data in the buffer, or if we haven't sent in a while, send now:
        if not (
            self._force_send
            or len(buf) >= self.MIN_PACKET_SIZE
            or self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS < time.monotonic()
        ):
            return
        # buffer always ends on a full newline-framed json message
        assert buf[-2:] in (b"}\n", b"]\n"), f"unexpected json-rpc terminator: {buf[-2:]=!r}"
        # either (1) pad length to next power of two, to create "lsize" packet:
        # (note: for a payload that is exactly a power of two, bit_length() rounds
        #  up to the *next* power, e.g. 1024 -> 2048)
        payload_lsize = len(buf)
        total_lsize = max(self.MIN_PACKET_SIZE, 2 ** (payload_lsize.bit_length()))
        npad_lsize = total_lsize - payload_lsize
        # or if that wasted a lot of bandwidth with padding, (2) defer sending some messages
        # and create a packet with half that size ("ssize", s for small)
        total_ssize = max(self.MIN_PACKET_SIZE, total_lsize // 2)
        payload_ssize = buf.rfind(b"\n", 0, total_ssize)
        if payload_ssize != -1:
            payload_ssize += 1  # for "\n" char
            npad_ssize = total_ssize - payload_ssize
        else:
            # no complete message fits in the small window; option (2) unavailable
            npad_ssize = float("inf")
        # decide between (1) and (2):
        if self._force_send or npad_lsize <= npad_ssize:
            # (1) create "lsize" packet: consume full buffer
            npad = npad_lsize
            p_idx = payload_lsize
        else:
            # (2) create "ssize" packet: consume some, but defer some for later
            npad = npad_ssize
            p_idx = payload_ssize
        # pad by adding spaces near end
        # self.session.maybe_log(
        #     f"PaddedRSTransport. calling low-level write(). "
        #     f"chose between (lsize:{payload_lsize}+{npad_lsize}, ssize:{payload_ssize}+{npad_ssize}). "
        #     f"won: {'tie' if npad_lsize == npad_ssize else 'lsize' if npad_lsize < npad_ssize else 'ssize'}."
        # )
        json_rpc_terminator = buf[p_idx-2:p_idx]
        assert json_rpc_terminator in (b"}\n", b"]\n"), f"unexpected {json_rpc_terminator=!r}"
        # whitespace inserted *inside* the last message (before its terminator) is
        # ignored by json parsers, so padding is transparent to the peer
        buf2 = buf[:p_idx-2] + (npad * b" ") + json_rpc_terminator
        self._asyncio_transport.write(buf2)
        self._last_send = time.monotonic()
        del self._sbuffer[:p_idx]
        if not self._sbuffer:
            self._sbuffer_has_data_evt.clear()

    async def _poll_sbuffer(self):
        """Background task: periodically flush messages that were deferred
        waiting for the buffer to grow."""
        while True:
            await self._sbuffer_has_data_evt.wait()  # to avoid busy-waiting
            self._maybe_consume_sbuffer()
            # If there is still data in the buffer, sleep until it would time out.
            # note: If the transport is ~idle, when we wake up, we will send the current buf data,
            #       but if busy, we might wake up to completely new buffer contents. Either is fine.
            if len(self._sbuffer) > 0:
                timeout_abs = self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS
                timeout_rel = max(0.0, timeout_abs - time.monotonic())
                await asyncio.sleep(timeout_rel)

    def connection_made(self, transport: asyncio.BaseTransport):
        super().connection_made(transport)
        if isinstance(self.session, NotificationSession):
            # long-lived session: run the flush poller on the interface's taskgroup
            coro = self.session.taskgroup.spawn(self._poll_sbuffer())
            self._sbuffer_task = self.loop.create_task(coro)
        else:
            # This is a short-lived "fetch_certificate"-type session.
            # No polling here, we always force-empty the buffer.
            self._force_send = True
382

383

384
class ServerAddr:
    """Identifies an Electrum server: host, port, and protocol ('t'/'s')."""

    def __init__(self, host: str, port: Union[int, str], *, protocol: Optional[str] = None):
        # raises ValueError for invalid host/port/protocol
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'  # default to TLS
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Constructs a ServerAddr or raises ValueError.
        Expects the full "host:port:protocol" form.
        """
        # host might be IPv6 address, hence do rsplit:
        host, port, protocol = str(s).rsplit(':', 2)
        return ServerAddr(host=host, port=port, protocol=protocol)

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.
        Does not raise - just returns None if guessing failed.
        Ongoing compatibility not guaranteed.
        """
        if not s:
            return None
        host = ""
        if s[0] == "[" and "]" in s:  # IPv6 address
            host_end = s.index("]")
            host = s[1:host_end]
            s = s[host_end+1:]  # keep only ":port[:protocol]" remainder
        items = str(s).rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        host = host or items[0]
        port = items[1]
        if len(items) >= 3:
            protocol = items[2]
        else:
            protocol = PREFERRED_NETWORK_PROTOCOL
        try:
            return ServerAddr(host=host, port=port, protocol=protocol)
        except ValueError:
            return None

    def to_friendly_name(self) -> str:
        """Short display form, omitting the default protocol suffix."""
        # note: this method is closely linked to from_str_with_inference
        if self.protocol == 's':  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        # "host:port" without the protocol suffix
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            return False
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash((self.host, self.port, self.protocol))
469

470

471
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
5✔
472
    filename = host
5✔
473
    try:
5✔
474
        ip = ip_address(host)
5✔
475
    except ValueError:
5✔
476
        pass
5✔
477
    else:
478
        if isinstance(ip, IPv6Address):
×
479
            filename = f"ipv6_{ip.packed.hex()}"
×
480
    return os.path.join(config.path, 'certs', filename)
5✔
481

482

483
class Interface(Logger):
5✔
484

485
    LOGGING_SHORTCUT = 'i'
5✔
486

487
    def __init__(self, *, network: 'Network', server: ServerAddr):
        """Create the interface and schedule its run() coroutine on the
        network taskgroup; the connection is established asynchronously.
        """
        self.ready = network.asyncio_loop.create_future()  # resolved once handshake done
        self.got_disconnected = asyncio.Event()
        self.server = server
        Logger.__init__(self)
        assert network.config.path  # needed to locate the on-disk cert store
        self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
        self.blockchain = None  # type: Optional[Blockchain]
        self._requested_chunks = set()  # type: Set[int]
        self.network = network
        self.session = None  # type: Optional[NotificationSession]
        self._ipaddr_bucket = None
        # Set up proxy.
        # - for servers running on localhost, the proxy is not used. If user runs their own server
        #   on same machine, this lets them enable the proxy (which is used for e.g. FX rates).
        #   note: we could maybe relax this further and bypass the proxy for all private
        #         addresses...? e.g. 192.168.x.x
        if util.is_localhost(server.host):
            self.logger.info(f"looks like localhost: not using proxy for this server")
            self.proxy = None
        else:
            self.proxy = ESocksProxy.from_network_settings(network)

        # Latest block header and corresponding height, as claimed by the server.
        # Note that these values are updated before they are verified.
        # Especially during initial header sync, verification can take a long time.
        # Failing verification will get the interface closed.
        self.tip_header = None
        self.tip = 0

        self.fee_estimates_eta = {}  # type: Dict[int, int]

        # Dump network messages (only for this interface).  Set at runtime from the console.
        self.debug = False

        self.taskgroup = OldTaskGroup()

        async def spawn_task():
            task = await self.network.taskgroup.spawn(self.run())
            task.set_name(f"interface::{str(server)}")
        # __init__ may run on a non-event-loop thread, hence run_coroutine_threadsafe
        asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)
528

529
    @property
    def host(self):
        # convenience accessor for the server's host
        return self.server.host
5✔
532

533
    @property
    def port(self):
        # convenience accessor for the server's port
        return self.server.port
536

537
    @property
    def protocol(self):
        # convenience accessor for the server's protocol ('t'/'s')
        return self.server.protocol
540

541
    def diagnostic_name(self):
        # used by Logger to tag this interface's log lines
        return self.server.net_addr_str()
543

544
    def __str__(self):
        return f"<Interface {self.diagnostic_name()}>"
546

547
    async def is_server_ca_signed(self, ca_ssl_context):
        """Given a CA enforcing SSL context, returns True if the connection
        can be established. Returns False if the server has a self-signed
        certificate but otherwise is okay. Any other failures raise.
        """
        try:
            await self.open_session(ca_ssl_context, exit_early=True)
        except ConnectError as e:
            # ConnectError chains the underlying error as __cause__ (see _RSClient)
            cause = e.__cause__
            if (isinstance(cause, ssl.SSLCertVerificationError)
                    and cause.reason == 'CERTIFICATE_VERIFY_FAILED'
                    and cause.verify_code == 18):  # "self signed certificate"
                # Good. We will use this server as self-signed.
                return False
            # Not good. Cannot use this server.
            raise
        # Good. We will use this server as CA-signed.
        return True
565

566
    async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
        """Probe the server and record its cert locally: an empty file marks a
        CA-signed server; otherwise the self-signed cert itself is pinned."""
        ca_signed = await self.is_server_ca_signed(ca_ssl_context)
        if ca_signed:
            # pinning a fingerprint makes no sense for CA-validated servers
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            with open(self.cert_path, 'w') as f:
                # empty file means this is CA signed, not self-signed
                f.write('')
        else:
            await self._save_certificate()
576

577
    def _is_saved_ssl_cert_available(self):
        """Return whether a usable cert file for this server exists on disk.

        An empty file marks a CA-signed server; a non-empty file is a pinned
        self-signed cert, which is parsed, date-checked, and fingerprint-checked
        (the fingerprint check may raise -- see _verify_certificate_fingerprint).
        Raises ErrorParsingSSLCert if the saved cert cannot be parsed.
        """
        if not os.path.exists(self.cert_path):
            return False
        with open(self.cert_path, 'r') as f:
            contents = f.read()
        if contents == '':  # CA signed
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            return True
        # pinned self-signed cert
        try:
            b = pem.dePem(contents, 'CERTIFICATE')
        except SyntaxError as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x = x509.X509(b)
        except Exception as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x.check_date()
        except x509.CertificateError as e:
            self.logger.info(f"certificate has expired: {e}")
            os.unlink(self.cert_path)  # delete pinned cert only in this case
            return False
        self._verify_certificate_fingerprint(bytearray(b))
        return True
605

606
    async def _get_ssl_context(self):
        """Build the SSLContext for this server, fetching and pinning its cert
        on first use. Returns None for plaintext ('t') connections.

        Raises ErrorGettingSSLCertFromServer / ErrorParsingSSLCert on failure.
        """
        if self.protocol != 's':
            # using plaintext TCP
            return None

        # see if we already have cert for this server; or get it for the first time
        ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
        if not self._is_saved_ssl_cert_available():
            try:
                await self._try_saving_ssl_cert_for_first_time(ca_sslc)
            except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
                raise ErrorGettingSSLCertFromServer(e) from e
        # now we have a file saved in our certificate store
        siz = os.stat(self.cert_path).st_size
        if siz == 0:
            # CA signed cert
            sslc = ca_sslc
        else:
            # pinned self-signed cert
            sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
            # note: Flag "ssl.VERIFY_X509_STRICT" is enabled by default in python 3.13+ (disabled in older versions).
            #       We explicitly disable it as it breaks lots of servers.
            sslc.verify_flags &= ~ssl.VERIFY_X509_STRICT
            sslc.check_hostname = False
        return sslc
631

632
    def handle_disconnect(func):
        """Decorator for Interface coroutines: on exit (normal or error), mark
        the interface disconnected and notify the network.

        note: not a @staticmethod -- it is only applied at class-definition
        time to wrap methods below (e.g. run()).
        """
        @functools.wraps(func)
        async def wrapper_func(self: 'Interface', *args, **kwargs):
            try:
                return await func(self, *args, **kwargs)
            except GracefulDisconnect as e:
                # log at the severity carried by the exception itself
                self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
            except aiorpcx.jsonrpc.RPCError as e:
                self.logger.warning(f"disconnecting due to {repr(e)}")
                self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
            finally:
                self.got_disconnected.set()
                await self.network.connection_down(self)
                # if was not 'ready' yet, schedule waiting coroutines:
                self.ready.cancel()
        return wrapper_func
648

649
    @ignore_exceptions  # do not kill network.taskgroup
    @log_exceptions
    @handle_disconnect
    async def run(self):
        """Main entry point: obtain the SSL context and run the session.
        Expected connection failures are logged and swallowed; cleanup is
        handled by @handle_disconnect."""
        try:
            ssl_context = await self._get_ssl_context()
        except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
            self.logger.info(f'disconnecting due to: {repr(e)}')
            return
        try:
            await self.open_session(ssl_context)
        except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
            # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
            if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
                    and self.is_main_server() and not self.network.auto_connect):
                self.logger.warning(f'Cannot connect to main server due to SSL error '
                                    f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
            else:
                self.logger.info(f'disconnecting due to: {repr(e)}')
            return
669

670
    def _mark_ready(self) -> None:
        """Resolve the *ready* future once the server's tip header is known,
        assigning this interface to the matching (or best) blockchain fork.
        Idempotent; raises GracefulDisconnect if *ready* was already cancelled.
        """
        if self.ready.cancelled():
            raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
        if self.ready.done():
            return

        assert self.tip_header
        chain = blockchain.check_header(self.tip_header)
        if not chain:
            # server tip not on any known chain yet; tentatively use the best one
            self.blockchain = blockchain.get_best_chain()
        else:
            self.blockchain = chain
        assert self.blockchain is not None

        self.logger.info(f"set blockchain with height {self.blockchain.height()}")

        self.ready.set_result(1)
687

688
    def is_connected_and_ready(self) -> bool:
        """Whether the handshake completed and no disconnect happened since."""
        if not self.ready.done():
            return False
        return not self.got_disconnected.is_set()
690

691
    async def _save_certificate(self) -> None:
        """Fetch the server's TLS certificate and persist it (PEM) at self.cert_path,
        for TOFU-style cert pinning. No-op if a cert file already exists.

        Raises GracefulDisconnect if the cert could not be obtained after 10 tries.
        """
        if not os.path.exists(self.cert_path):
            # we may need to retry this a few times, in case the handshake hasn't completed
            for _ in range(10):
                dercert = await self._fetch_certificate()
                if dercert:
                    self.logger.info("succeeded in getting cert")
                    # verify against pinned fingerprint (if any) BEFORE saving
                    self._verify_certificate_fingerprint(dercert)
                    with open(self.cert_path, 'w') as f:
                        cert = ssl.DER_cert_to_PEM_cert(dercert)
                        # workaround android bug
                        cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
                        f.write(cert)
                        # even though close flushes, we can't fsync when closed.
                        # and we must flush before fsyncing, cause flush flushes to OS buffer
                        # fsync writes to OS buffer to disk
                        f.flush()
                        os.fsync(f.fileno())
                    break
                await asyncio.sleep(1)
            else:
                raise GracefulDisconnect("could not get certificate after 10 tries")
×
713

714
    async def _fetch_certificate(self) -> bytes:
        """Connect with certificate verification disabled, solely to retrieve the
        server's certificate in DER form.

        NOTE(review): getpeercert() may return None if the handshake has not
        completed; callers (_save_certificate) retry on a falsy result.
        """
        # verification intentionally off: we only want to *see* the cert here
        sslc = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
        sslc.check_hostname = False
        sslc.verify_mode = ssl.CERT_NONE
        async with _RSClient(
            session_factory=RPCSession,
            host=self.host, port=self.port,
            ssl=sslc,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        ) as session:
            asyncio_transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
            ssl_object = asyncio_transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
            return ssl_object.getpeercert(binary_form=True)
×
728

729
    def _get_expected_fingerprint(self) -> Optional[str]:
        """Return the user-configured pinned cert fingerprint, only for the main server."""
        if not self.is_main_server():
            return None
        return self.network.config.NETWORK_SERVERFINGERPRINT
×
732

733
    def _verify_certificate_fingerprint(self, certificate):
        """Check a DER certificate against the configured pinned fingerprint.

        No-op if no fingerprint is pinned. On mismatch, fires the
        'cert_mismatch' callback and raises ErrorSSLCertFingerprintMismatch.

        certificate: DER-encoded certificate bytes (hashed with sha256).
        """
        expected_fingerprint = self._get_expected_fingerprint()
        if not expected_fingerprint:
            return
        fingerprint = hashlib.sha256(certificate).hexdigest()
        # case-insensitive hex comparison
        fingerprints_match = fingerprint.lower() == expected_fingerprint.lower()
        if not fingerprints_match:
            util.trigger_callback('cert_mismatch')
            raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
        self.logger.info("cert fingerprint verification passed")
×
743

744
    async def get_block_header(self, height, assert_mode):
        """Request a single block header from the server and deserialize it.

        height: non-negative block height (validated; raises otherwise).
        assert_mode: label for logging only (e.g. 'catchup', 'binary', 'backward').
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        self.logger.info(f'requesting block header {height} in mode {assert_mode}')
        # use lower timeout as we usually have network.bhi_lock here
        timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
        res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
        return blockchain.deserialize_header(bytes.fromhex(res), height)
×
752

753
    async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
        """Request the 2016-header chunk containing *height*, validate it, and try
        to connect it to self.blockchain.

        tip: if given, caps the request size so we do not ask past the server tip.
        can_return_early: if True and this chunk is already being requested
            concurrently, return None instead of re-requesting.

        Returns (conn, num_headers) where conn is the result of connect_chunk
        (falsy if the chunk did not connect), or None when returning early.
        Raises RequestCorrupted on malformed/inconsistent server responses.
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        index = height // 2016
        if can_return_early and index in self._requested_chunks:
            return
        self.logger.info(f"requesting chunk from height {height}")
        size = 2016
        if tip is not None:
            # clamp to [0, headers remaining up to tip]
            size = min(size, tip - index * 2016 + 1)
            size = max(size, 0)
        try:
            self._requested_chunks.add(index)
            res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
        finally:
            self._requested_chunks.discard(index)
        assert_dict_contains_field(res, field_name='count')
        assert_dict_contains_field(res, field_name='hex')
        assert_dict_contains_field(res, field_name='max')
        assert_non_negative_integer(res['count'])
        assert_non_negative_integer(res['max'])
        assert_hex_str(res['hex'])
        if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
            raise RequestCorrupted('inconsistent chunk hex and count')
        # we never request more than 2016 headers, but we enforce those fit in a single response
        if res['max'] < 2016:
            raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016")
        if res['count'] != size:
            raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
        conn = self.blockchain.connect_chunk(index, res['hex'])
        if not conn:
            return conn, 0
        return conn, res['count']
×
786

787
    def is_main_server(self) -> bool:
        """Return True iff this interface is (or is slated to become) the network's main server."""
        net = self.network
        if net.interface == self:
            return True
        return net.interface is None and net.default_server == self.server
790

791
    async def open_session(self, sslc, exit_early=False):
        """Open the JSON-RPC session, negotiate the protocol version, and run the
        interface's worker tasks until disconnect.

        sslc: SSL context to use for the connection.
        exit_early: if True, return right after a successful server.version
            exchange (used e.g. for probing), without spawning workers.

        Raises GracefulDisconnect on protocol-version problems, overfull
        server buckets, or (some) server-reported RPC errors.
        """
        session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
        async with _RSClient(
            session_factory=session_factory,
            host=self.host, port=self.port,
            ssl=sslc,
            proxy=self.proxy,
            transport=PaddedRSTransport,
        ) as session:
            self.session = session  # type: NotificationSession
            self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
            try:
                ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
            except aiorpcx.jsonrpc.RPCError as e:
                raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
            if exit_early:
                return
            if ver[1] != version.PROTOCOL_VERSION:
                raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                         f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
            if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
                raise GracefulDisconnect(f'too many connected servers already '
                                         f'in bucket {self.bucket_based_on_ipaddress()}')
            self.logger.info(f"connection established. version: {ver}")

            try:
                # worker tasks; when any of them raises, the whole group is torn down
                async with self.taskgroup as group:
                    await group.spawn(self.ping)
                    await group.spawn(self.request_fee_estimates)
                    await group.spawn(self.run_fetch_blocks)
                    await group.spawn(self.monitor_connection)
            except aiorpcx.jsonrpc.RPCError as e:
                if e.code in (
                    JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                    JSONRPC.SERVER_BUSY,
                    JSONRPC.METHOD_NOT_FOUND,
                    JSONRPC.INTERNAL_ERROR,
                ):
                    # these are expected-ish server-side conditions -> disconnect gracefully
                    log_level = logging.WARNING if self.is_main_server() else logging.INFO
                    raise GracefulDisconnect(e, log_level=log_level) from e
                raise
            finally:
                self.got_disconnected.set()  # set this ASAP, ideally before any awaits
×
834

835
    async def monitor_connection(self):
        """Poll (1s) for the session/transport being closed and raise
        GracefulDisconnect when it is, tearing down the taskgroup.
        """
        while True:
            await asyncio.sleep(1)
            # If the session/transport is no longer open, we disconnect.
            # e.g. if the remote cleanly sends EOF, we would handle that here.
            # note: If the user pulls the ethernet cable or disconnects wifi,
            #       ideally we would detect that here, so that the GUI/etc can reflect that.
            #       - On Android, this seems to work reliably , where asyncio.BaseProtocol.connection_lost()
            #         gets called with e.g. ConnectionAbortedError(103, 'Software caused connection abort').
            #       - On desktop Linux/Win, it seems BaseProtocol.connection_lost() is not called in such cases.
            #         Hence, in practice the connection issue will only be detected the next time we try
            #         to send a message (plus timeout), which can take minutes...
            if not self.session or self.session.is_closing():
                raise GracefulDisconnect('session was closed')
×
849

850
    async def ping(self):
        """Keep-alive loop: ping the server at randomized intervals (0-300s),
        occasionally followed by extra noise pings.
        """
        # We periodically send a "ping" msg to make sure the server knows we are still here.
        # Adding a bit of randomness generates some noise against traffic analysis.
        while True:
            await asyncio.sleep(random.random() * 300)
            await self.session.send_request('server.ping')
            await self._maybe_send_noise()
×
857

858
    async def _maybe_send_noise(self):
        """Probabilistically send a short burst of extra pings (each with 20%
        chance of continuing, after a sub-second random delay), as cover
        traffic against traffic analysis.
        """
        while True:
            if random.random() >= 0.2:
                break
            await asyncio.sleep(random.random())
            await self.session.send_request('server.ping')
×
862

863
    async def request_fee_estimates(self):
        """Every 60s, request fee estimates for all ETA targets (except the last)
        in parallel, store positive int results in self.fee_estimates_eta, and
        notify the network.
        """
        while True:
            async with OldTaskGroup() as group:
                fee_tasks = []
                for i in FEE_ETA_TARGETS[0:-1]:
                    fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
            for nblock_target, task in fee_tasks:
                fee = task.result()
                # -1 signals "server could not estimate" -> skip
                if fee < 0: continue
                assert isinstance(fee, int)
                self.fee_estimates_eta[nblock_target] = fee
            self.network.update_fee_estimates()
            await asyncio.sleep(60)
×
876

877
    async def close(self, *, force_after: Optional[int] = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.

        force_after: seconds after which the transport is aborted forcibly
            (passed through to the session), or None for the session default.
        """
        if self.session:
            await self.session.close(force_after=force_after)
        # monitor_connection will cancel tasks
884

885
    async def run_fetch_blocks(self):
        """Subscribe to new block headers and process each tip update:
        track the server tip, mark the interface ready, sync/connect headers,
        fire UI callbacks, and let the network reconsider interface choice.

        Raises GracefulDisconnect if the server tip falls below the last checkpoint.
        """
        header_queue = asyncio.Queue()
        await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
        while True:
            item = await header_queue.get()
            raw_header = item[0]
            height = raw_header['height']
            header = blockchain.deserialize_header(bfh(raw_header['hex']), height)
            self.tip_header = header
            self.tip = height
            if self.tip < constants.net.max_checkpoint():
                raise GracefulDisconnect('server tip below max checkpoint')
            self._mark_ready()
            blockchain_updated = await self._process_header_at_tip()
            # header processing done
            if self.is_main_server():
                self.logger.info(f"new chain tip on main interface. {height=}")
            if blockchain_updated:
                util.trigger_callback('blockchain_updated')
            util.trigger_callback('network_updated')
            await self.network.switch_unwanted_fork_interface()
            await self.network.switch_lagging_interface()
            # fire-and-forget some noise pings for traffic analysis resistance
            await self.taskgroup.spawn(self._maybe_send_noise())
×
908

909
    async def _process_header_at_tip(self) -> bool:
        """Connect the current server tip header into our local chain state.

        Returns:
        False - boring fast-forward: we already have this header as part of this blockchain from another interface,
        True - new header we didn't have, or reorg
        """
        height, header = self.tip, self.tip_header
        async with self.network.bhi_lock:
            if self.blockchain.height() >= height and self.blockchain.check_header(header):
                # another interface amended the blockchain
                return False
            _, height = await self.step(height, header)
            # in the simple case, height == self.tip+1
            if height <= self.tip:
                # we connected below the server tip -> keep syncing up to it
                await self.sync_until(height)
            return True
×
924

925
    async def sync_until(self, height, next_height=None):
        """Download and connect headers from *height* up to *next_height*
        (defaults to the server tip), using bulk chunk requests when far behind
        and single-header steps otherwise.

        Returns (last, height): the last step outcome label and the next height.
        Raises GracefulDisconnect if the server chain conflicts with checkpoints.
        """
        if next_height is None:
            next_height = self.tip
        last = None
        while last is None or height <= next_height:
            # remember previous state to detect a stuck loop below
            prev_last, prev_height = last, height
            if next_height > height + 10:
                # far behind: fetch a whole 2016-header chunk at once
                could_connect, num_headers = await self.request_chunk(height, next_height)
                if not could_connect:
                    if height <= constants.net.max_checkpoint():
                        raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                    last, height = await self.step(height)
                    continue
                util.trigger_callback('blockchain_updated')
                util.trigger_callback('network_updated')
                height = (height // 2016 * 2016) + num_headers
                assert height <= next_height+1, (height, self.tip)
                last = 'catchup'
            else:
                last, height = await self.step(height)
            assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
        return last, height
5✔
947

948
    async def step(self, height, header=None):
        """Advance header sync by one step at *height*.

        Tries, in order: the header already checks against a known chain;
        the header connects to the current chain tip; otherwise searches
        backwards then binary-searches for the forkpoint and resolves the fork.

        header: optional pre-fetched header for *height*; fetched if None.
        Returns (outcome_label, next_height), e.g. ('catchup', height+1).
        Note: headers may carry a 'mock' key in tests, substituting chain checks.
        """
        assert 0 <= height <= self.tip, (height, self.tip)
        if header is None:
            header = await self.get_block_header(height, 'catchup')

        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        if chain:
            self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
            # note: there is an edge case here that is not handled.
            # we might know the blockhash (enough for check_header) but
            # not have the header itself. e.g. regtest chain with only genesis.
            # this situation resolves itself on the next block
            return 'catchup', height+1

        can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
        if not can_connect:
            self.logger.info(f"can't connect new block: {height=}")
            height, header, bad, bad_header = await self._search_headers_backwards(height, header)
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            assert chain or can_connect
        if can_connect:
            self.logger.info(f"new block: {height=}")
            height += 1
            if isinstance(can_connect, Blockchain):  # not when mocking
                self.blockchain = can_connect
                self.blockchain.save_header(header)
            return 'catchup', height

        good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
        return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
5✔
979

980
    async def _search_headers_binary(self, height, bad, bad_header, chain):
        """Binary-search between a known-good height and a known-bad height to
        find the exact forkpoint.

        height: height known to check against a chain ("good").
        bad / bad_header: height and header known NOT to check against any chain.
        chain: chain that *height* checks against (may update self.blockchain).
        Returns (good, bad, bad_header) with good + 1 == bad.
        """
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
        good = height
        while True:
            assert good < bad, (good, bad)
            height = (good + bad) // 2
            self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
            header = await self.get_block_header(height, 'binary')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            if chain:
                self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
                good = height
            else:
                bad = height
                bad_header = header
            if good + 1 == bad:
                break

        # sanity: the first bad header must at least *connect* to the chain at the forkpoint
        mock = 'mock' in bad_header and bad_header['mock']['connect'](height)
        real = not mock and self.blockchain.can_connect(bad_header, check_height=False)
        if not real and not mock:
            raise Exception('unexpected bad header during binary: {}'.format(bad_header))
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.logger.info(f"binary search exited. good {good}, bad {bad}")
        return good, bad, bad_header
5✔
1009

1010
    async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
        """Given an exact forkpoint (good + 1 == bad), either continue catching
        up (if our chain simply ends at *good*) or create a new fork chain.

        Returns ('no_fork', next_height) or ('fork', next_height).
        """
        assert good + 1 == bad
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)
        # 'good' is the height of a block 'good_header', somewhere in self.blockchain.
        # bad_header connects to good_header; bad_header itself is NOT in self.blockchain.

        bh = self.blockchain.height()
        assert bh >= good, (bh, good)
        if bh == good:
            height = good + 1
            self.logger.info(f"catching up from {height}")
            return 'no_fork', height

        # this is a new fork we don't yet have
        height = bad + 1
        self.logger.info(f"new fork at bad height {bad}")
        forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork']
        b = forkfun(bad_header)  # type: Blockchain
        self.blockchain = b
        assert b.forkpoint == bad
        return 'fork', height
5✔
1032

1033
    async def _search_headers_backwards(self, height, header):
        """Walk backwards (with exponentially growing stride from the tip) until
        we find a header that checks against, or connects to, a known chain.

        Returns (height, header, bad, bad_header) where header is the first
        acceptable header found and bad/bad_header the last unacceptable one.
        Raises GracefulDisconnect if we hit the checkpoint without agreement.
        """
        async def iterate():
            # fetch header at current *height*; return True to keep searching backwards
            nonlocal height, header
            checkp = False
            if height <= constants.net.max_checkpoint():
                height = constants.net.max_checkpoint()
                checkp = True
            header = await self.get_block_header(height, 'backward')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            if chain or can_connect:
                return False
            if checkp:
                raise GracefulDisconnect("server chain conflicts with checkpoints")
            return True

        bad, bad_header = height, header
        _assert_header_does_not_check_against_any_chain(bad_header)
        with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values())
        local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf')
        # start just above our best local chain (or just below the bad header)
        height = min(local_max + 1, height - 1)
        while await iterate():
            bad, bad_header = height, header
            # double the distance from the tip each round
            delta = self.tip - height
            height = self.tip - 2 * delta

        _assert_header_does_not_check_against_any_chain(bad_header)
        self.logger.info(f"exiting backward mode at {height}")
        return height, header, bad, bad_header
5✔
1062

1063
    @classmethod
    def client_name(cls) -> str:
        """Client identifier sent to servers (in 'server.version')."""
        return 'electrum/%s' % version.ELECTRUM_VERSION
×
1066

1067
    def is_tor(self):
        """Whether the server host is a Tor onion address."""
        hostname = self.host
        return hostname.endswith('.onion')
×
1069

1070
    def ip_addr(self) -> Optional[str]:
        """Remote peer address (host part) as a string, or None if not connected."""
        sess = self.session
        if not sess:
            return None
        remote = sess.remote_address()
        return str(remote.host) if remote else None
×
1076

1077
    def bucket_based_on_ipaddress(self) -> str:
        """Group servers into "buckets" by network neighbourhood, to enforce a
        healthy spread of connected servers: onion servers share one bucket,
        IPv4 buckets are /16 subnets, IPv6 buckets are /48 subnets.

        Returns '' for localhost, unknown, or unparsable addresses.
        The result is cached in self._ipaddr_bucket.
        """
        def do_bucket():
            if self.is_tor():
                return BUCKET_NAME_OF_ONION_SERVERS
            try:
                ip_addr = ip_address(self.ip_addr())  # type: Union[IPv4Address, IPv6Address]
            except ValueError:
                # also raised when self.ip_addr() is None
                return ''
            if not ip_addr:
                return ''
            if ip_addr.is_loopback:  # localhost is exempt
                return ''
            if ip_addr.version == 4:
                slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
                return str(slash16)
            elif ip_addr.version == 6:
                slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
                return str(slash48)
            return ''

        if not self._ipaddr_bucket:
            self._ipaddr_bucket = do_bucket()
        return self._ipaddr_bucket
×
1100

1101
    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
        """Request the merkle proof for a confirmed transaction.

        tx_hash: txid (hash256 hex string); tx_height: block height hint.
        Returns the validated server response dict with keys
        'block_height', 'merkle' (list of hash256 strings) and 'pos'.
        """
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        # do request
        res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
        # check response
        block_height = assert_dict_contains_field(res, field_name='block_height')
        merkle = assert_dict_contains_field(res, field_name='merkle')
        pos = assert_dict_contains_field(res, field_name='pos')
        # note: tx_height was just a hint to the server, don't enforce the response to match it
        assert_non_negative_integer(block_height)
        assert_non_negative_integer(pos)
        assert_list_or_tuple(merkle)
        for item in merkle:
            assert_hash256_str(item)
        return res
×
1119

1120
    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
        """Request a raw transaction by txid and validate that it deserializes
        and actually hashes to the requested txid.

        Returns the raw transaction hex string.
        Raises RequestCorrupted if the response is non-hex, undeserializable,
        or does not match tx_hash.
        """
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
        # validate response
        if not is_hex_str(raw):
            raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw!r}")
        tx = Transaction(raw)
        try:
            tx.deserialize()  # see if raises
        except Exception as e:
            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
        if tx.txid() != tx_hash:
            raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
        return raw
×
1135

1136
    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
        """Request and validate the (confirmed + mempool) history of a scripthash.

        sh: scripthash (hash256 hex string).
        Returns the server response: a list of dicts with 'height' and 'tx_hash'
        (mempool items — height -1 or 0 — additionally carry 'fee').
        Raises RequestCorrupted on malformed responses: missing fields,
        non-monotonic confirmed heights, or duplicate txids.
        """
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
        # check response
        assert_list_or_tuple(res)
        prev_height = 1
        for tx_item in res:
            height = assert_dict_contains_field(tx_item, field_name='height')
            assert_dict_contains_field(tx_item, field_name='tx_hash')
            assert_integer(height)
            assert_hash256_str(tx_item['tx_hash'])
            if height in (-1, 0):
                # mempool tx: must report its fee
                assert_dict_contains_field(tx_item, field_name='fee')
                assert_non_negative_integer(tx_item['fee'])
                prev_height = float("inf")  # this ensures confirmed txs can't follow mempool txs
            else:
                # check monotonicity of heights
                if height < prev_height:
                    raise RequestCorrupted('heights of confirmed txs must be in increasing order')
                prev_height = height
        hashes = set(map(lambda item: item['tx_hash'], res))
        if len(hashes) != len(res):
            # Either server is sending garbage... or maybe if server is race-prone
            # a recently mined tx could be included in both last block and mempool?
            # Still, it's simplest to just disregard the response.
            raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
        return res
×
1165

1166
    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
        """Request and validate the UTXO list for a scripthash.

        Returns a list of dicts with 'tx_pos', 'value', 'tx_hash', 'height'.
        """
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
        # check response
        assert_list_or_tuple(res)
        for utxo_item in res:
            assert_dict_contains_field(utxo_item, field_name='tx_pos')
            assert_dict_contains_field(utxo_item, field_name='value')
            assert_dict_contains_field(utxo_item, field_name='tx_hash')
            assert_dict_contains_field(utxo_item, field_name='height')
            assert_non_negative_integer(utxo_item['tx_pos'])
            assert_non_negative_integer(utxo_item['value'])
            assert_non_negative_integer(utxo_item['height'])
            assert_hash256_str(utxo_item['tx_hash'])
        return res
×
1183

1184
    async def get_balance_for_scripthash(self, sh: str) -> dict:
        """Request and validate the balance of a scripthash.

        Returns a dict with non-negative int 'confirmed' and int 'unconfirmed'
        (unconfirmed may be negative, e.g. when the mempool spends outputs).
        """
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        # do request
        res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
        # check response
        assert_dict_contains_field(res, field_name='confirmed')
        assert_dict_contains_field(res, field_name='unconfirmed')
        assert_non_negative_integer(res['confirmed'])
        assert_integer(res['unconfirmed'])
        return res
×
1195

1196
    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
        """Request the txid of the transaction at (block height, position).

        merkle: if True, the response is a dict with 'tx_hash' and a 'merkle'
        proof; otherwise it is the txid string alone. Validates either shape.
        """
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        if not is_non_negative_integer(tx_pos):
            raise Exception(f"{repr(tx_pos)} should be non-negative integer")
        # do request
        res = await self.session.send_request(
            'blockchain.transaction.id_from_pos',
            [tx_height, tx_pos, merkle],
        )
        # check response
        if merkle:
            assert_dict_contains_field(res, field_name='tx_hash')
            assert_dict_contains_field(res, field_name='merkle')
            assert_hash256_str(res['tx_hash'])
            assert_list_or_tuple(res['merkle'])
            for node_hash in res['merkle']:
                assert_hash256_str(node_hash)
        else:
            assert_hash256_str(res)
        return res
×
1217

1218
    async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
        """Request the mempool fee histogram: a list of (fee_rate, vsize) pairs.

        Validates that fee rates are strictly decreasing and all values are
        non-negative. Raises RequestCorrupted otherwise.
        """
        # do request
        res = await self.session.send_request('mempool.get_fee_histogram')
        # check response
        assert_list_or_tuple(res)
        prev_fee = float('inf')
        for fee, s in res:
            assert_non_negative_int_or_float(fee)
            assert_non_negative_integer(s)
            if fee >= prev_fee:  # check monotonicity
                raise RequestCorrupted('fees must be in decreasing order')
            prev_fee = fee
        return res
×
1231

1232
    async def get_server_banner(self) -> str:
        """Request the server's banner text; enforce it is a str."""
        # do request
        res = await self.session.send_request('server.banner')
        # check response
        if not isinstance(res, str):
            raise RequestCorrupted(f'{res!r} should be a str')
        return res
×
1239

1240
    async def get_donation_address(self) -> str:
        """Request the server's donation address.

        Returns '' when the server has none or the address is unrecognized
        (logged, not raised — the server may use a future address type).
        """
        # do request
        res = await self.session.send_request('server.donation_address')
        # check response
        if not res:  # ignore empty string
            return ''
        if not bitcoin.is_address(res):
            # note: do not hard-fail -- allow server to use future-type
            #       bitcoin address we do not recognize
            self.logger.info(f"invalid donation address from server: {repr(res)}")
            res = ''
        return res
×
1252

1253
    async def get_relay_fee(self) -> int:
        """Returns the min relay feerate in sat/kbyte."""
        # do request
        res = await self.session.send_request('blockchain.relayfee')
        # check response; server reports BTC/kbyte -> convert to sat/kbyte
        assert_non_negative_int_or_float(res)
        relayfee = int(res * bitcoin.COIN)
        relayfee = max(0, relayfee)
        return relayfee
×
1262

1263
    async def get_estimatefee(self, num_blocks: int) -> int:
        """Returns a feerate estimate for getting confirmed within
        num_blocks blocks, in sat/kbyte.
        Returns -1 if the server could not provide an estimate.
        """
        if not is_non_negative_integer(num_blocks):
            raise Exception(f"{repr(num_blocks)} is not a num_blocks")
        # do request
        try:
            res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
        except aiorpcx.jsonrpc.ProtocolError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate, however apparently "electrs" does not conform
            # and sends an error instead. Convert it here:
            if "cannot estimate fee" in e.message:
                res = -1
            else:
                raise
        except aiorpcx.jsonrpc.RPCError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate. "Fulcrum" often sends:
            #   aiorpcx.jsonrpc.RPCError: (-32603, 'internal error: bitcoind request timed out')
            if e.code == JSONRPC.INTERNAL_ERROR:
                res = -1
            else:
                raise
        # check response; server reports BTC/kbyte -> convert to sat/kbyte
        if res != -1:
            assert_non_negative_int_or_float(res)
            res = int(res * bitcoin.COIN)
        return res
×
1294

1295

1296
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
    """Sanity check used during forkpoint search: raise if *header* checks
    against a known chain (it is supposed to be a "bad" header).
    Headers may carry a 'mock' key in tests, substituting the chain check.
    """
    chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
    if chain_bad:
        raise Exception('bad_header must not check!')
×
1300

1301

1302
def check_cert(host, cert):
    """Diagnostic helper: parse a PEM certificate and print whether it has
    expired. Parse failures are printed to stdout instead of raised.
    """
    try:
        der = pem.dePem(cert, 'CERTIFICATE')
        parsed = x509.X509(der)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        return

    try:
        parsed.check_date()
        has_expired = False
    except Exception:
        has_expired = True

    util.print_msg("host: {}\nhas_expired: {}\n".format(host, has_expired))
×
1319

1320

1321
# Used by tests
1322
def _match_hostname(name, val):
5✔
1323
    if val == name:
×
1324
        return True
×
1325

1326
    return val.startswith('*.') and name.endswith(val[1:])
×
1327

1328

1329
def test_certificates():
    """Manual diagnostic: run check_cert() on every saved server certificate
    in the user's 'certs' directory. Invoked via the __main__ guard below.
    """
    from .simple_config import SimpleConfig
    config = SimpleConfig()
    mydir = os.path.join(config.path, "certs")
    certs = os.listdir(mydir)
    for c in certs:
        p = os.path.join(mydir,c)
        with open(p, encoding='utf-8') as f:
            cert = f.read()
        check_cert(c, cert)
×
1339

1340
# manual entry point: check all locally saved server certificates
if __name__ == "__main__":
    test_certificates()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc