• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

spesmilo / electrum / 5926929871667200

28 May 2025 04:07PM UTC coverage: 59.441% (-0.09%) from 59.527%
5926929871667200

Pull #9875

CirrusCI

SomberNight
interface: add padding and some noise to protocol messages

basic countermeasures against traffic analysis
Pull Request #9875: interface: add padding and some noise to protocol messages

14 of 78 new or added lines in 1 file covered. (17.95%)

4 existing lines in 4 files now uncovered.

21690 of 36490 relevant lines covered (59.44%)

2.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

35.64
/electrum/interface.py
1
#!/usr/bin/env python
2
#
3
# Electrum - lightweight Bitcoin client
4
# Copyright (C) 2011 thomasv@gitorious
5
#
6
# Permission is hereby granted, free of charge, to any person
7
# obtaining a copy of this software and associated documentation files
8
# (the "Software"), to deal in the Software without restriction,
9
# including without limitation the rights to use, copy, modify, merge,
10
# publish, distribute, sublicense, and/or sell copies of the Software,
11
# and to permit persons to whom the Software is furnished to do so,
12
# subject to the following conditions:
13
#
14
# The above copyright notice and this permission notice shall be
15
# included in all copies or substantial portions of the Software.
16
#
17
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25
import os
4✔
26
import re
4✔
27
import ssl
4✔
28
import sys
4✔
29
import time
4✔
30
import traceback
4✔
31
import asyncio
4✔
32
import socket
4✔
33
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
4✔
34
from collections import defaultdict
4✔
35
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
4✔
36
import itertools
4✔
37
import logging
4✔
38
import hashlib
4✔
39
import functools
4✔
40
import random
4✔
41

42
import aiorpcx
4✔
43
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
4✔
44
from aiorpcx.curio import timeout_after, TaskTimeout
4✔
45
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
4✔
46
from aiorpcx.rawsocket import RSClient, RSTransport
4✔
47
import certifi
4✔
48

49
from .util import (ignore_exceptions, log_exceptions, bfh, ESocksProxy,
4✔
50
                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
51
                   is_int_or_float, is_non_negative_int_or_float, OldTaskGroup)
52
from . import util
4✔
53
from . import x509
4✔
54
from . import pem
4✔
55
from . import version
4✔
56
from . import blockchain
4✔
57
from .blockchain import Blockchain, HEADER_SIZE
4✔
58
from . import bitcoin
4✔
59
from . import constants
4✔
60
from .i18n import _
4✔
61
from .logging import Logger
4✔
62
from .transaction import Transaction
4✔
63
from .fee_policy import FEE_ETA_TARGETS
4✔
64

65
if TYPE_CHECKING:
4✔
66
    from .network import Network
×
67
    from .simple_config import SimpleConfig
×
68

69

70
# Path of the CA bundle shipped with certifi; used to validate CA-signed server certs.
ca_path = certifi.where()

# Bucket name under which Tor/onion servers are grouped (they lack meaningful IPs).
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# 't' = plaintext TCP, 's' = TLS
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
77

78

79
class NetworkTimeout:
    """Timeout presets for network operations, in seconds.

    ``Generic`` is for ordinary requests; ``Urgent`` tightens the deadlines
    for operations where we prefer to fail fast.
    """
    # seconds
    class Generic:
        NORMAL = 30
        RELAXED = 45
        MOST_RELAXED = 600

    class Urgent(Generic):
        NORMAL = 10
        RELAXED = 20
        MOST_RELAXED = 60
90

91

92
def assert_non_negative_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative integer."""
    if is_non_negative_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative integer')
95

96

97
def assert_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an integer."""
    if is_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be an integer')
100

101

102
def assert_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an int or a float."""
    if is_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be int or float')
105

106

107
def assert_non_negative_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative int or float."""
    if is_non_negative_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
110

111

112
def assert_hash256_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* looks like a hash256 hex string."""
    if is_hash256_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hash256 str')
115

116

117
def assert_hex_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hex string."""
    if is_hex_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hex str')
120

121

122
def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Return ``d[field_name]``.

    Raises RequestCorrupted if *d* is not a dict or lacks *field_name*.
    """
    if not isinstance(d, dict):
        raise RequestCorrupted(f'{d!r} should be a dict')
    if field_name in d:
        return d[field_name]
    raise RequestCorrupted(f'required field {field_name!r} missing from dict')
128

129

130
def assert_list_or_tuple(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a list or a tuple."""
    if isinstance(val, (list, tuple)):
        return
    raise RequestCorrupted(f'{val!r} should be a list or tuple')
133

134

135
class NotificationSession(RPCSession):
    """aiorpcx RPCSession with support for server-pushed notifications.

    Incoming notifications are fanned out to per-subscription asyncio queues,
    and the most recent result per subscription is cached so that a repeated
    'subscribe' call does not necessarily hit the network again.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super(NotificationSession, self).__init__(*args, **kwargs)
        # subscription key -> list of asyncio.Queue listening for notifications
        self.subscriptions = defaultdict(list)
        # subscription key -> latest result received for that subscription
        self.cache = {}
        self._msg_counter = itertools.count(start=1)  # ids used only to correlate log lines
        self.interface = interface
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        """Handle an incoming message from the server.

        Only notifications for active subscriptions are expected; anything
        else is treated as a protocol violation and closes the session.
        """
        self.maybe_log(f"--> {request}")
        try:
            if isinstance(request, Notification):
                # convention: last arg is the result, the rest are the params
                params, result = request.args[:-1], request.args[-1]
                key = self.get_hashable_key_for_rpc_call(request.method, params)
                if key in self.subscriptions:
                    self.cache[key] = result
                    for queue in self.subscriptions[key]:
                        await queue.put(request.args)
                else:
                    raise Exception(f'unexpected notification')
            else:
                raise Exception(f'unexpected request. not a notification')
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        """Send a JSON-RPC request; log traffic and map timeouts to RequestTimedOut."""
        # note: semaphores/timeouts/backpressure etc are handled by
        # aiorpcx. the timeout arg here in most cases should not be set
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # note: RPCSession.send_request raises TaskTimeout in case of a timeout.
            # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
            response = await util.wait_for2(
                super().send_request(*args, **kwargs),
                timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            self.maybe_log(f"--> request timed out: {args} (id: {msg_id})")
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        except BaseException as e:  # cancellations, etc. are useful for debugging
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        else:
            self.maybe_log(f"--> {response} (id: {msg_id})")
            return response

    def set_default_timeout(self, timeout):
        """Set the default request/send timeouts used by the aiorpcx base class."""
        assert hasattr(self, "sent_request_timeout")  # in base class
        self.sent_request_timeout = timeout
        assert hasattr(self, "max_send_delay")        # in base class
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        """Register *queue* for (method, params) and push the initial result to it."""
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key in self.cache:
            result = self.cache[key]
        else:
            result = await self.send_request(method, params)
            self.cache[key] = result
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for v in self.subscriptions.values():
            if queue in v:
                v.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return str(method) + repr(params)

    def maybe_log(self, msg: str) -> None:
        # log traffic only if debugging is enabled for this interface or globally
        if not self.interface: return
        if self.interface.debug or self.interface.network.debug:
            self.interface.logger.debug(msg)

    def default_framer(self):
        # overridden so that max_size can be customized
        max_size = self.interface.network.config.NETWORK_MAX_INCOMING_MSG_SIZE
        assert max_size > 500_000, f"{max_size=} (< 500_000) is too small"
        return NewlineFramer(max_size=max_size)

    async def close(self, *, force_after: int = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.
        """
        if force_after is None:
            # We give up after a while and just abort the connection.
            # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
            #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
            # Note: if the ethernet cable was pulled or wifi disconnected, that too might
            #       wait until this timeout is triggered
            force_after = 1  # seconds
        await super().close(force_after=force_after)
241

242

243
class NetworkException(Exception):
    """Base class for network-layer errors."""
    pass
244

245

246
class GracefulDisconnect(NetworkException):
    """An expected/benign disconnect; carries the severity used when logging it."""

    log_level = logging.INFO  # class-level default severity

    def __init__(self, *args, log_level=None, **kwargs):
        Exception.__init__(self, *args, **kwargs)
        if log_level is None:
            return
        # per-instance override of the class default
        self.log_level = log_level
253

254

255
class RequestTimedOut(GracefulDisconnect):
    """Raised when a network request did not complete in time."""

    def __str__(self):
        return _("Network request timed out.")
258

259

260
class RequestCorrupted(Exception):
    """The server sent a malformed or invalid response."""
    pass
261

262
class ErrorParsingSSLCert(Exception):
    """A saved or received SSL certificate could not be parsed."""
    pass

class ErrorGettingSSLCertFromServer(Exception):
    """The server's SSL certificate could not be fetched."""
    pass

class ErrorSSLCertFingerprintMismatch(Exception):
    """The server certificate does not match the pinned fingerprint."""
    pass

class InvalidOptionCombination(Exception):
    """Mutually exclusive configuration options were combined."""
    pass

class ConnectError(NetworkException):
    """Failed to establish a connection to the server."""
    pass
267

268

269
class _RSClient(RSClient):
    """RSClient that defaults to PaddedRSTransport and surfaces OSError as ConnectError."""

    def __init__(self, *, transport=None, **kwargs):
        chosen_transport = PaddedRSTransport if transport is None else transport
        RSClient.__init__(self, transport=chosen_transport, **kwargs)

    async def create_connection(self):
        try:
            return await super().create_connection()
        except OSError as exc:
            # note: using "from exc" here will set __cause__ of ConnectError
            raise ConnectError(exc) from exc
281

282

283
class PaddedRSTransport(RSTransport):
    """A raw socket transport that provides basic countermeasures against traffic analysis
    by padding the jsonrpc payload with whitespaces to have ~uniform-size TCP packets.
    (it is assumed that a network observer does not see plaintext transport contents,
    due to it being wrapped e.g. in TLS)
    """

    # Never emit a packet smaller than this; payload is space-padded up to it.
    MIN_PACKET_SIZE = 1024
    # Max time a message may sit in the send buffer waiting to be batched with more data.
    WAIT_FOR_BUFFER_GROWTH_SECONDS = 1.0

    session: Optional['RPCSession']

    def __init__(self, *args, **kwargs):
        RSTransport.__init__(self, *args, **kwargs)
        self._sbuffer = bytearray()  # "send buffer"
        self._sbuffer_task = None  # type: Optional[asyncio.Task]
        self._sbuffer_has_data_evt = asyncio.Event()
        self._last_send = time.monotonic()  # when we last wrote to the wire
        self._force_send = False  # type: bool

    # note: this does not call super().write() but is a complete reimplementation
    async def write(self, message):
        """Frame *message*, append it to the send buffer, and maybe flush now."""
        await self._can_send.wait()
        if self.is_closing():
            return
        framed_message = self._framer.frame(message)
        self._sbuffer += framed_message
        self._sbuffer_has_data_evt.set()
        self._maybe_consume_sbuffer()

    def _maybe_consume_sbuffer(self):
        """Write (some of) the buffered data to the wire, padded, if a flush
        condition is met: forced send, buffer at least MIN_PACKET_SIZE, or the
        buffer has been waiting longer than WAIT_FOR_BUFFER_GROWTH_SECONDS.
        """
        if not self._can_send.is_set() or self.is_closing():
            return
        buf = self._sbuffer
        if not buf:
            return
        # if there is enough data in the buffer, or if we haven't sent in a while, send now:
        if not (
            self._force_send
            or len(buf) >= self.MIN_PACKET_SIZE
            or self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS < time.monotonic()
        ):
            return
        # buffer holds newline-framed json-rpc messages, so it must end in "}\n" or "]\n"
        assert buf[-2:] in (b"}\n", b"]\n"), f"unexpected json-rpc terminator: {buf[-2:]=!r}"
        # either (1) pad length to next power of two, to create "lsize" packet:
        payload_lsize = len(buf)
        total_lsize = max(self.MIN_PACKET_SIZE, 2 ** (payload_lsize.bit_length()))
        npad_lsize = total_lsize - payload_lsize
        # or if that wasted a lot of bandwidth with padding, (2) defer sending some messages
        # and create a packet with half that size ("ssize", s for small)
        total_ssize = max(self.MIN_PACKET_SIZE, total_lsize // 2)
        # last complete message boundary that fits in the smaller packet:
        payload_ssize = buf.rfind(b"\n", 0, total_ssize)
        if payload_ssize != -1:
            payload_ssize += 1  # for "\n" char
            npad_ssize = total_ssize - payload_ssize
        else:
            npad_ssize = float("inf")  # no complete message fits; (2) not possible
        # decide between (1) and (2):
        if self._force_send or npad_lsize <= npad_ssize:
            # (1) create "lsize" packet: consume full buffer
            npad = npad_lsize
            p_idx = payload_lsize
        else:
            # (2) create "ssize" packet: consume some, but defer some for later
            npad = npad_ssize
            p_idx = payload_ssize
        # pad by adding spaces near end
        # self.session.maybe_log(
        #     f"PaddedRSTransport. calling low-level write(). "
        #     f"chose between (lsize:{payload_lsize}+{npad_lsize}, ssize:{payload_ssize}+{npad_ssize}). "
        #     f"won: {'tie' if npad_lsize == npad_ssize else 'lsize' if npad_lsize < npad_ssize else 'ssize'}."
        # )
        # the spaces are inserted just before the final "}\n"/"]\n", so the last
        # json message in the packet still parses (json ignores whitespace)
        json_rpc_terminator = buf[p_idx-2:p_idx]
        assert json_rpc_terminator in (b"}\n", b"]\n"), f"unexpected {json_rpc_terminator=!r}"
        buf2 = buf[:p_idx-2] + (npad * b" ") + json_rpc_terminator
        self._asyncio_transport.write(buf2)
        self._last_send = time.monotonic()
        del self._sbuffer[:p_idx]
        if not self._sbuffer:
            self._sbuffer_has_data_evt.clear()

    async def _poll_sbuffer(self):
        """Background task: flush data left waiting in the send buffer."""
        while True:
            await self._sbuffer_has_data_evt.wait()  # to avoid busy-waiting
            self._maybe_consume_sbuffer()
            # If there is still data in the buffer, sleep until it would time out.
            # note: If the transport is ~idle, when we wake up, we will send the current buf data,
            #       but if busy, we might wake up to completely new buffer contents. Either is fine.
            if len(self._sbuffer) > 0:
                timeout_abs = self._last_send + self.WAIT_FOR_BUFFER_GROWTH_SECONDS
                timeout_rel = max(0.0, timeout_abs - time.monotonic())
                await asyncio.sleep(timeout_rel)

    def connection_made(self, transport: asyncio.BaseTransport):
        super().connection_made(transport)
        if isinstance(self.session, NotificationSession):
            # long-lived session: run the flushing task inside the interface's taskgroup
            coro = self.session.interface.taskgroup.spawn(self._poll_sbuffer())
            self._sbuffer_task = self.loop.create_task(coro)
        else:
            # This a short-lived "fetch_certificate"-type session.
            # No polling here, we always force-empty the buffer.
            self._force_send = True
385

386

387
class ServerAddr:
    """An Electrum server endpoint: (host, port, protocol)."""

    def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
        # protocol defaults to 's' (TLS); raises ValueError on invalid host/port/protocol
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Constructs a ServerAddr or raises ValueError."""
        # host might be IPv6 address, hence do rsplit:
        host, port, protocol = str(s).rsplit(':', 2)
        return ServerAddr(host=host, port=port, protocol=protocol)

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.
        Does not raise - just returns None if guessing failed.
        Ongoing compatibility not guaranteed.
        """
        if not s:
            return None
        host = ""
        if s[0] == "[" and "]" in s:  # IPv6 address
            host_end = s.index("]")
            host = s[1:host_end]
            s = s[host_end+1:]
        items = str(s).rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        host = host or items[0]
        port = items[1]
        if len(items) >= 3:
            protocol = items[2]
        else:
            protocol = PREFERRED_NETWORK_PROTOCOL
        try:
            return ServerAddr(host=host, port=port, protocol=protocol)
        except ValueError:
            return None

    def to_friendly_name(self) -> str:
        """Short display form; hides the protocol suffix for the default ('s')."""
        # note: this method is closely linked to from_str_with_inference
        if self.protocol == 's':  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        """Serialized form used in the wallet/config files ("host:port:protocol")."""
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        """'host:port' string, without the protocol suffix."""
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            return False
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # consistent with __eq__: same (host, port, protocol) triple
        return hash((self.host, self.port, self.protocol))
472

473

474
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
4✔
475
    filename = host
4✔
476
    try:
4✔
477
        ip = ip_address(host)
4✔
478
    except ValueError:
4✔
479
        pass
4✔
480
    else:
481
        if isinstance(ip, IPv6Address):
×
482
            filename = f"ipv6_{ip.packed.hex()}"
×
483
    return os.path.join(config.path, 'certs', filename)
4✔
484

485

486
class Interface(Logger):
4✔
487

488
    LOGGING_SHORTCUT = 'i'
4✔
489

490
    def __init__(self, *, network: 'Network', server: ServerAddr):
        """Create the interface and schedule its main coroutine (self.run) on
        the network's event loop.
        """
        # resolved once the connection is established and the tip header verified
        self.ready = network.asyncio_loop.create_future()
        self.got_disconnected = asyncio.Event()
        self.server = server
        Logger.__init__(self)
        assert network.config.path
        self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
        self.blockchain = None  # type: Optional[Blockchain]
        self._requested_chunks = set()  # type: Set[int]
        self.network = network
        self.session = None  # type: Optional[NotificationSession]
        self._ipaddr_bucket = None
        # Set up proxy.
        # - for servers running on localhost, the proxy is not used. If user runs their own server
        #   on same machine, this lets them enable the proxy (which is used for e.g. FX rates).
        #   note: we could maybe relax this further and bypass the proxy for all private
        #         addresses...? e.g. 192.168.x.x
        if util.is_localhost(server.host):
            self.logger.info(f"looks like localhost: not using proxy for this server")
            self.proxy = None
        else:
            self.proxy = ESocksProxy.from_network_settings(network)

        # Latest block header and corresponding height, as claimed by the server.
        # Note that these values are updated before they are verified.
        # Especially during initial header sync, verification can take a long time.
        # Failing verification will get the interface closed.
        self.tip_header = None
        self.tip = 0

        self.fee_estimates_eta = {}  # type: Dict[int, int]

        # Dump network messages (only for this interface).  Set at runtime from the console.
        self.debug = False

        self.taskgroup = OldTaskGroup()

        async def spawn_task():
            task = await self.network.taskgroup.spawn(self.run())
            task.set_name(f"interface::{str(server)}")
        asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)
531

532
    @property
    def host(self):
        """Server hostname or IP address (str)."""
        return self.server.host
535

536
    @property
    def port(self):
        """Server TCP port (int)."""
        return self.server.port
539

540
    @property
    def protocol(self):
        """Network protocol letter: 't' (plain TCP) or 's' (TLS)."""
        return self.server.protocol
543

544
    def diagnostic_name(self):
        """Name used in log lines for this interface ('host:port')."""
        return self.server.net_addr_str()
546

547
    def __str__(self):
        # human-readable form, e.g. "<Interface example.com:50002>"
        return f"<Interface {self.diagnostic_name()}>"
549

550
    async def is_server_ca_signed(self, ca_ssl_context):
        """Given a CA enforcing SSL context, returns True if the connection
        can be established. Returns False if the server has a self-signed
        certificate but otherwise is okay. Any other failures raise.
        """
        try:
            await self.open_session(ca_ssl_context, exit_early=True)
        except ConnectError as e:
            # inspect the underlying SSL failure to distinguish "self-signed"
            # from genuinely broken certificates
            cause = e.__cause__
            if (isinstance(cause, ssl.SSLCertVerificationError)
                    and cause.reason == 'CERTIFICATE_VERIFY_FAILED'
                    and cause.verify_code == 18):  # "self signed certificate"
                # Good. We will use this server as self-signed.
                return False
            # Not good. Cannot use this server.
            raise
        # Good. We will use this server as CA-signed.
        return True
568

569
    async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
        """Probe the server and persist its certificate status at self.cert_path.

        An empty file marks the server as CA-signed; otherwise the server's
        self-signed certificate is pinned to the file.
        """
        ca_signed = await self.is_server_ca_signed(ca_ssl_context)
        if ca_signed:
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            with open(self.cert_path, 'w') as f:
                # empty file means this is CA signed, not self-signed
                f.write('')
        else:
            await self._save_certificate()
579

580
    def _is_saved_ssl_cert_available(self):
        """Return True iff a usable certificate record exists at self.cert_path.

        An empty file means the server is CA-signed. A non-empty file pins a
        self-signed certificate, which must parse, be within its validity
        dates, and match the expected fingerprint (if one is configured).
        Raises ErrorParsingSSLCert / InvalidOptionCombination on bad records.
        """
        if not os.path.exists(self.cert_path):
            return False
        with open(self.cert_path, 'r') as f:
            contents = f.read()
        if contents == '':  # CA signed
            if self._get_expected_fingerprint():
                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
            return True
        # pinned self-signed cert
        try:
            b = pem.dePem(contents, 'CERTIFICATE')
        except SyntaxError as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x = x509.X509(b)
        except Exception as e:
            self.logger.info(f"error parsing already saved cert: {e}")
            raise ErrorParsingSSLCert(e) from e
        try:
            x.check_date()
        except x509.CertificateError as e:
            self.logger.info(f"certificate has expired: {e}")
            os.unlink(self.cert_path)  # delete pinned cert only in this case
            return False
        self._verify_certificate_fingerprint(bytearray(b))
        return True
608

609
    async def _get_ssl_context(self):
        """Build the ssl.SSLContext for this server, or return None for plain TCP."""
        if self.protocol != 's':
            # using plaintext TCP
            return None

        # see if we already have cert for this server; or get it for the first time
        ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
        if not self._is_saved_ssl_cert_available():
            try:
                await self._try_saving_ssl_cert_for_first_time(ca_sslc)
            except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
                raise ErrorGettingSSLCertFromServer(e) from e
        # now we have a file saved in our certificate store
        siz = os.stat(self.cert_path).st_size
        if siz == 0:
            # CA signed cert
            sslc = ca_sslc
        else:
            # pinned self-signed cert
            sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
            # note: Flag "ssl.VERIFY_X509_STRICT" is enabled by default in python 3.13+ (disabled in older versions).
            #       We explicitly disable it as it breaks lots of servers.
            sslc.verify_flags &= ~ssl.VERIFY_X509_STRICT
            sslc.check_hostname = False
        return sslc
634

635
    def handle_disconnect(func):
        """Decorator for Interface coroutine methods: logs expected disconnects
        and RPC errors, and unconditionally marks the interface as down and
        notifies the Network when the wrapped coroutine finishes.
        """
        @functools.wraps(func)
        async def wrapper_func(self: 'Interface', *args, **kwargs):
            try:
                return await func(self, *args, **kwargs)
            except GracefulDisconnect as e:
                self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
            except aiorpcx.jsonrpc.RPCError as e:
                self.logger.warning(f"disconnecting due to {repr(e)}")
                self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
            finally:
                self.got_disconnected.set()
                await self.network.connection_down(self)
                # if was not 'ready' yet, schedule waiting coroutines:
                self.ready.cancel()
        return wrapper_func
651

652
    @ignore_exceptions  # do not kill network.taskgroup
    @log_exceptions
    @handle_disconnect
    async def run(self):
        """Main coroutine for this interface: obtain the SSL context, then
        open and serve the session until disconnect.
        """
        try:
            ssl_context = await self._get_ssl_context()
        except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
            self.logger.info(f'disconnecting due to: {repr(e)}')
            return
        try:
            await self.open_session(ssl_context)
        except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
            # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
            if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
                    and self.is_main_server() and not self.network.auto_connect):
                self.logger.warning(f'Cannot connect to main server due to SSL error '
                                    f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
            else:
                self.logger.info(f'disconnecting due to: {repr(e)}')
            return
672

673
    def _mark_ready(self) -> None:
        """Resolve the *ready* future once the server's tip header is known
        and a blockchain (fork) has been chosen for this interface.
        """
        if self.ready.cancelled():
            raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
        if self.ready.done():
            return

        assert self.tip_header
        chain = blockchain.check_header(self.tip_header)
        if not chain:
            # tip not on any known chain yet; start from the best known chain
            self.blockchain = blockchain.get_best_chain()
        else:
            self.blockchain = chain
        assert self.blockchain is not None

        self.logger.info(f"set blockchain with height {self.blockchain.height()}")

        self.ready.set_result(1)
690

691
    def is_connected_and_ready(self) -> bool:
        """Whether the handshake completed (*ready* resolved) and we have not since disconnected."""
        if not self.ready.done():
            return False
        return not self.got_disconnected.is_set()
693

694
    async def _save_certificate(self) -> None:
        """Fetch the server's TLS certificate and pin it to disk at self.cert_path.

        No-op if a cert file already exists (trust-on-first-use). Retries for up
        to ~10 seconds, as the TLS handshake may not have completed yet.
        Raises GracefulDisconnect if no cert could be obtained.
        """
        if not os.path.exists(self.cert_path):
            # we may need to retry this a few times, in case the handshake hasn't completed
            for _ in range(10):
                dercert = await self._fetch_certificate()
                if dercert:
                    self.logger.info("succeeded in getting cert")
                    # check against a user-pinned fingerprint (if any) before persisting
                    self._verify_certificate_fingerprint(dercert)
                    with open(self.cert_path, 'w') as f:
                        cert = ssl.DER_cert_to_PEM_cert(dercert)
                        # workaround android bug
                        cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
                        f.write(cert)
                        # even though close flushes, we can't fsync when closed.
                        # and we must flush before fsyncing, cause flush flushes to OS buffer
                        # fsync writes to OS buffer to disk
                        f.flush()
                        os.fsync(f.fileno())
                    break
                await asyncio.sleep(1)
            else:
                raise GracefulDisconnect("could not get certificate after 10 tries")
716

717
    async def _fetch_certificate(self) -> bytes:
        """Connect once *without* verification and return the server's DER-encoded cert."""
        # deliberately disable all verification: we only want to obtain the cert here,
        # it gets checked against the pinned fingerprint separately
        context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        async with _RSClient(session_factory=RPCSession,
                             host=self.host, port=self.port,
                             ssl=context, proxy=self.proxy) as session:
            transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
            tls_obj = transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
            return tls_obj.getpeercert(binary_form=True)
727

728
    def _get_expected_fingerprint(self) -> Optional[str]:
4✔
729
        if self.is_main_server():
×
730
            return self.network.config.NETWORK_SERVERFINGERPRINT
×
731

732
    def _verify_certificate_fingerprint(self, certificate):
4✔
733
        expected_fingerprint = self._get_expected_fingerprint()
×
734
        if not expected_fingerprint:
×
735
            return
×
736
        fingerprint = hashlib.sha256(certificate).hexdigest()
×
737
        fingerprints_match = fingerprint.lower() == expected_fingerprint.lower()
×
738
        if not fingerprints_match:
×
739
            util.trigger_callback('cert_mismatch')
×
740
            raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
×
741
        self.logger.info("cert fingerprint verification passed")
×
742

743
    async def get_block_header(self, height, assert_mode):
        """Request the header at *height* from the server and deserialize it."""
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        self.logger.info(f'requesting block header {height} in mode {assert_mode}')
        # use lower timeout as we usually have network.bhi_lock here
        urgent_timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
        header_hex = await self.session.send_request(
            'blockchain.block.header', [height], timeout=urgent_timeout)
        return blockchain.deserialize_header(bytes.fromhex(header_hex), height)
751

752
    async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
        """Download and connect the 2016-header chunk containing *height*.

        Returns a (connect_result, num_headers) tuple, where num_headers is 0
        if the chunk did not connect. Returns None (early, no tuple!) when
        *can_return_early* is set and another task is already requesting the
        same chunk.
        """
        if not is_non_negative_integer(height):
            raise Exception(f"{repr(height)} is not a block height")
        index = height // 2016  # chunk index
        if can_return_early and index in self._requested_chunks:
            return
        self.logger.info(f"requesting chunk from height {height}")
        size = 2016
        if tip is not None:
            # clamp so we don't request headers beyond the server's tip
            size = min(size, tip - index * 2016 + 1)
            size = max(size, 0)
        try:
            # de-duplicate concurrent requests for the same chunk
            self._requested_chunks.add(index)
            res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
        finally:
            self._requested_chunks.discard(index)
        # validate the response before touching our blockchain state
        assert_dict_contains_field(res, field_name='count')
        assert_dict_contains_field(res, field_name='hex')
        assert_dict_contains_field(res, field_name='max')
        assert_non_negative_integer(res['count'])
        assert_non_negative_integer(res['max'])
        assert_hex_str(res['hex'])
        if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
            raise RequestCorrupted('inconsistent chunk hex and count')
        # we never request more than 2016 headers, but we enforce those fit in a single response
        if res['max'] < 2016:
            raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016")
        if res['count'] != size:
            raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
        conn = self.blockchain.connect_chunk(index, res['hex'])
        if not conn:
            return conn, 0
        return conn, res['count']
785

786
    def is_main_server(self) -> bool:
        """Whether this interface is (or is destined to become) the network's main server."""
        if self.network.interface == self:
            return True
        return self.network.interface is None and self.network.default_server == self.server
789

790
    async def open_session(self, sslc, exit_early=False):
        """Open the RPC session, negotiate the protocol version, then run the
        per-interface tasks (ping, fee polling, header fetching, watchdog)
        until disconnect.

        Raises GracefulDisconnect on protocol-negotiation failure, or when the
        IP-bucket of this server is already saturated. *exit_early* returns
        right after a successful server.version exchange (used for probing).
        """
        session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
        async with _RSClient(session_factory=session_factory,
                             host=self.host, port=self.port,
                             ssl=sslc, proxy=self.proxy) as session:
            self.session = session  # type: NotificationSession
            self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
            try:
                ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
            except aiorpcx.jsonrpc.RPCError as e:
                raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
            if exit_early:
                return
            # the server must echo back exactly the protocol version we requested
            if ver[1] != version.PROTOCOL_VERSION:
                raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                         f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
            if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
                raise GracefulDisconnect(f'too many connected servers already '
                                         f'in bucket {self.bucket_based_on_ipaddress()}')
            self.logger.info(f"connection established. version: {ver}")

            try:
                async with self.taskgroup as group:
                    await group.spawn(self.ping)
                    await group.spawn(self.request_fee_estimates)
                    await group.spawn(self.run_fetch_blocks)
                    await group.spawn(self.monitor_connection)
            except aiorpcx.jsonrpc.RPCError as e:
                # transient/limited server errors become a graceful disconnect
                if e.code in (
                    JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                    JSONRPC.SERVER_BUSY,
                    JSONRPC.METHOD_NOT_FOUND,
                    JSONRPC.INTERNAL_ERROR,
                ):
                    log_level = logging.WARNING if self.is_main_server() else logging.INFO
                    raise GracefulDisconnect(e, log_level=log_level) from e
                raise
            finally:
                self.got_disconnected.set()  # set this ASAP, ideally before any awaits
829

830
    async def monitor_connection(self):
        """Watchdog: poll once per second and raise GracefulDisconnect when
        the session/transport is gone."""
        while True:
            await asyncio.sleep(1)
            # If the session/transport is no longer open, we disconnect.
            # e.g. if the remote cleanly sends EOF, we would handle that here.
            # note: If the user pulls the ethernet cable or disconnects wifi,
            #       ideally we would detect that here, so that the GUI/etc can reflect that.
            #       - On Android, this seems to work reliably , where asyncio.BaseProtocol.connection_lost()
            #         gets called with e.g. ConnectionAbortedError(103, 'Software caused connection abort').
            #       - On desktop Linux/Win, it seems BaseProtocol.connection_lost() is not called in such cases.
            #         Hence, in practice the connection issue will only be detected the next time we try
            #         to send a message (plus timeout), which can take minutes...
            if not self.session or self.session.is_closing():
                raise GracefulDisconnect('session was closed')
844

845
    async def ping(self):
        """Keep-alive loop: send "server.ping" at randomised intervals.

        The interval is uniform in [0, 300) seconds rather than fixed, and each
        ping may be followed by extra noise pings (see _maybe_send_noise), as a
        basic countermeasure against traffic analysis.
        """
        # We periodically send a "ping" msg to make sure the server knows we are still here.
        # Adding a bit of randomness generates some noise against traffic analysis.
        while True:
            await asyncio.sleep(random.random() * 300)
            await self.session.send_request('server.ping')
            await self._maybe_send_noise()
852

853
    async def _maybe_send_noise(self):
        """Send a geometrically-distributed number of extra "server.ping" messages.

        Each iteration continues with probability 0.2, so most calls send
        nothing (expected number of extra pings: 0.25). This is a basic
        countermeasure against traffic analysis.
        """
        while random.random() < 0.2:
            await asyncio.sleep(random.random())
            await self.session.send_request('server.ping')
857

858
    async def request_fee_estimates(self):
        """Poll fee estimates for all ETA targets every 60 seconds and push
        them to the network. Runs until cancelled."""
        while True:
            async with OldTaskGroup() as group:
                fee_tasks = []
                for i in FEE_ETA_TARGETS[0:-1]:
                    fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
            for nblock_target, task in fee_tasks:
                fee = task.result()
                # -1 means the server could not provide an estimate; skip it
                if fee < 0: continue
                assert isinstance(fee, int)
                self.fee_estimates_eta[nblock_target] = fee
            self.network.update_fee_estimates()
            await asyncio.sleep(60)
871

872
    async def close(self, *, force_after: Optional[int] = None):
        """Closes the connection and waits for it to be closed.
        We try to flush buffered data to the wire, which can take some time.

        force_after: max seconds to wait for a graceful close before forcing
            it.  # NOTE(review): forwarded verbatim to the aiorpcx session -- confirm semantics there
        """
        if self.session:
            await self.session.close(force_after=force_after)
        # monitor_connection will cancel tasks
879

880
    async def run_fetch_blocks(self):
        """Subscribe to the server's header notifications and process each new tip.

        The first notification also marks this interface *ready* (see
        _mark_ready). Runs until cancelled/disconnected. Raises
        GracefulDisconnect if the server's tip is below our max checkpoint.
        """
        header_queue = asyncio.Queue()
        await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
        while True:
            item = await header_queue.get()
            raw_header = item[0]
            height = raw_header['height']
            header = blockchain.deserialize_header(bfh(raw_header['hex']), height)
            self.tip_header = header
            self.tip = height
            if self.tip < constants.net.max_checkpoint():
                raise GracefulDisconnect('server tip below max checkpoint')
            self._mark_ready()
            blockchain_updated = await self._process_header_at_tip()
            # header processing done
            if blockchain_updated:
                util.trigger_callback('blockchain_updated')
            util.trigger_callback('network_updated')
            await self.network.switch_unwanted_fork_interface()
            await self.network.switch_lagging_interface()
            # some noise against traffic analysis, piggybacking on block events
            await self.taskgroup.spawn(self._maybe_send_noise())
901

902
    async def _process_header_at_tip(self) -> bool:
        """Connect self.tip_header into our blockchain, syncing backwards if needed.

        Returns:
        False - boring fast-forward: we already have this header as part of this blockchain from another interface,
        True - new header we didn't have, or reorg
        """
        height, header = self.tip, self.tip_header
        async with self.network.bhi_lock:
            if self.blockchain.height() >= height and self.blockchain.check_header(header):
                # another interface amended the blockchain
                return False
            _, height = await self.step(height, header)
            # in the simple case, height == self.tip+1
            if height <= self.tip:
                await self.sync_until(height)
            return True
917

918
    async def sync_until(self, height, next_height=None):
        """Catch our blockchain up from *height* to *next_height* (default: server tip).

        Downloads whole 2016-header chunks when far behind (> 10 headers),
        otherwise advances one header at a time via step(). Returns the final
        (last_mode, height) pair from the last iteration.
        """
        if next_height is None:
            next_height = self.tip
        last = None
        while last is None or height <= next_height:
            # remember (mode, height) so we can assert forward progress below
            prev_last, prev_height = last, height
            if next_height > height + 10:
                could_connect, num_headers = await self.request_chunk(height, next_height)
                if not could_connect:
                    if height <= constants.net.max_checkpoint():
                        raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                    last, height = await self.step(height)
                    continue
                util.trigger_callback('blockchain_updated')
                util.trigger_callback('network_updated')
                # advance to just past the headers the chunk actually connected
                height = (height // 2016 * 2016) + num_headers
                assert height <= next_height+1, (height, self.tip)
                last = 'catchup'
            else:
                last, height = await self.step(height)
            assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
        return last, height
940

941
    async def step(self, height, header=None):
        """Examine the header at *height* and advance the sync by one step.

        Returns a (mode, next_height) tuple, where mode is one of
        'catchup', 'no_fork', 'fork'. May switch self.blockchain onto a
        different (forked) chain. Headers containing a 'mock' key are used by
        the unit tests to stub out blockchain checks.
        """
        assert 0 <= height <= self.tip, (height, self.tip)
        if header is None:
            header = await self.get_block_header(height, 'catchup')

        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        if chain:
            self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
            # note: there is an edge case here that is not handled.
            # we might know the blockhash (enough for check_header) but
            # not have the header itself. e.g. regtest chain with only genesis.
            # this situation resolves itself on the next block
            return 'catchup', height+1

        can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
        if not can_connect:
            self.logger.info(f"can't connect new block: {height=}")
            # walk backwards until we find where the server's chain meets ours
            height, header, bad, bad_header = await self._search_headers_backwards(height, header)
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            assert chain or can_connect
        if can_connect:
            self.logger.info(f"new block: {height=}")
            height += 1
            if isinstance(can_connect, Blockchain):  # not when mocking
                self.blockchain = can_connect
                self.blockchain.save_header(header)
            return 'catchup', height

        # neither checks nor connects: binary-search the exact forkpoint, then fork
        good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
        return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
972

973
    async def _search_headers_binary(self, height, bad, bad_header, chain):
        """Binary-search between a checking header at *height* and a
        non-checking one at *bad* for the exact forkpoint.

        Returns (good, bad, bad_header) with good + 1 == bad, where the header
        at *good* checks against a local chain and *bad_header* does not.
        """
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
        good = height
        while True:
            assert good < bad, (good, bad)
            height = (good + bad) // 2
            self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
            header = await self.get_block_header(height, 'binary')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            if chain:
                self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
                good = height
            else:
                bad = height
                bad_header = header
            if good + 1 == bad:
                break

        # bad_header must connect on top of the 'good' header we narrowed down to
        mock = 'mock' in bad_header and bad_header['mock']['connect'](height)
        real = not mock and self.blockchain.can_connect(bad_header, check_height=False)
        if not real and not mock:
            raise Exception('unexpected bad header during binary: {}'.format(bad_header))
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.logger.info(f"binary search exited. good {good}, bad {bad}")
        return good, bad, bad_header
1002

1003
    async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
        """Given the forkpoint found by binary search, either keep catching up
        ('no_fork') or create a new fork rooted at *bad* ('fork').

        Returns a (mode, next_height) tuple like step().
        """
        assert good + 1 == bad
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)
        # 'good' is the height of a block 'good_header', somewhere in self.blockchain.
        # bad_header connects to good_header; bad_header itself is NOT in self.blockchain.

        bh = self.blockchain.height()
        assert bh >= good, (bh, good)
        if bh == good:
            # our chain simply ends at the forkpoint: plain catchup, no fork needed
            height = good + 1
            self.logger.info(f"catching up from {height}")
            return 'no_fork', height

        # this is a new fork we don't yet have
        height = bad + 1
        self.logger.info(f"new fork at bad height {bad}")
        forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork']
        b = forkfun(bad_header)  # type: Blockchain
        self.blockchain = b
        assert b.forkpoint == bad
        return 'fork', height
1025

1026
    async def _search_headers_backwards(self, height, header):
        """Walk backwards from *height* (with exponentially growing stride)
        until a header is found that checks against, or connects to, a local chain.

        Returns (height, header, bad, bad_header) where *bad*/*bad_header* is
        the lowest examined header that did NOT check against any chain.
        Raises GracefulDisconnect if we hit the checkpoint without success.
        """
        async def iterate():
            # one probe; returns True to keep walking backwards
            nonlocal height, header
            checkp = False
            if height <= constants.net.max_checkpoint():
                height = constants.net.max_checkpoint()
                checkp = True
            header = await self.get_block_header(height, 'backward')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            if chain or can_connect:
                return False
            if checkp:
                raise GracefulDisconnect("server chain conflicts with checkpoints")
            return True

        bad, bad_header = height, header
        _assert_header_does_not_check_against_any_chain(bad_header)
        with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values())
        local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf')
        # start just above our best local height (or one below the server header)
        height = min(local_max + 1, height - 1)
        while await iterate():
            bad, bad_header = height, header
            # exponential backoff: double the distance from the server tip each time
            delta = self.tip - height
            height = self.tip - 2 * delta

        _assert_header_does_not_check_against_any_chain(bad_header)
        self.logger.info(f"exiting backward mode at {height}")
        return height, header, bad, bad_header
1055

1056
    @classmethod
    def client_name(cls) -> str:
        """The client identification string we send in server.version."""
        return 'electrum/' + version.ELECTRUM_VERSION
1059

1060
    def is_tor(self):
        """Whether the server is reached as a Tor hidden service (.onion host)."""
        return self.host[-6:] == '.onion'
1062

1063
    def ip_addr(self) -> Optional[str]:
        """The remote peer's IP address as a string, or None if unavailable."""
        session = self.session
        if not session:
            return None
        remote = session.remote_address()
        return str(remote.host) if remote else None
1069

1070
    def bucket_based_on_ipaddress(self) -> str:
        """Bucket used to enforce diversity among connected servers.

        Onion servers share one dedicated bucket; IPv4 servers are bucketed by
        their /16, IPv6 by their /48. Loopback and unparseable addresses get
        the empty string (exempt from bucketing). Cached after the first call.
        """
        def do_bucket():
            if self.is_tor():
                return BUCKET_NAME_OF_ONION_SERVERS
            try:
                ip_addr = ip_address(self.ip_addr())  # type: Union[IPv4Address, IPv6Address]
            except ValueError:
                return ''
            if not ip_addr:
                return ''
            if ip_addr.is_loopback:  # localhost is exempt
                return ''
            if ip_addr.version == 4:
                slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
                return str(slash16)
            elif ip_addr.version == 6:
                slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
                return str(slash48)
            return ''

        if not self._ipaddr_bucket:
            self._ipaddr_bucket = do_bucket()
        return self._ipaddr_bucket
1093

1094
    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
        """Request the merkle proof for *tx_hash* near *tx_height*; validate and return it."""
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        response = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
        # sanity-check the response before handing it to callers
        block_height = assert_dict_contains_field(response, field_name='block_height')
        merkle_branch = assert_dict_contains_field(response, field_name='merkle')
        pos = assert_dict_contains_field(response, field_name='pos')
        # note: tx_height was just a hint to the server, don't enforce the response to match it
        assert_non_negative_integer(block_height)
        assert_non_negative_integer(pos)
        assert_list_or_tuple(merkle_branch)
        for node_hash in merkle_branch:
            assert_hash256_str(node_hash)
        return response
1112

1113
    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
        """Fetch the raw hex of *tx_hash* and verify it deserializes to that txid."""
        if not is_hash256_str(tx_hash):
            raise Exception(f"{repr(tx_hash)} is not a txid")
        raw_hex = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
        # validate response: must be hex, must parse, must hash to the requested txid
        if not is_hex_str(raw_hex):
            raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw_hex!r}")
        parsed_tx = Transaction(raw_hex)
        try:
            parsed_tx.deserialize()  # see if raises
        except Exception as e:
            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
        if parsed_tx.txid() != tx_hash:
            raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {parsed_tx.txid()})")
        return raw_hex
1128

1129
    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
        """Request the (confirmed + mempool) history of scripthash *sh*; validate and return it."""
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        history = await self.session.send_request('blockchain.scripthash.get_history', [sh])
        # validate the response
        assert_list_or_tuple(history)
        last_height = 1
        for item in history:
            height = assert_dict_contains_field(item, field_name='height')
            assert_dict_contains_field(item, field_name='tx_hash')
            assert_integer(height)
            assert_hash256_str(item['tx_hash'])
            if height in (-1, 0):
                # mempool tx: must come with a fee
                assert_dict_contains_field(item, field_name='fee')
                assert_non_negative_integer(item['fee'])
                last_height = float("inf")  # this ensures confirmed txs can't follow mempool txs
            else:
                # check monotonicity of heights
                if height < last_height:
                    raise RequestCorrupted(f'heights of confirmed txs must be in increasing order')
                last_height = height
        txids = {item['tx_hash'] for item in history}
        if len(txids) != len(history):
            # Either server is sending garbage... or maybe if server is race-prone
            # a recently mined tx could be included in both last block and mempool?
            # Still, it's simplest to just disregard the response.
            raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
        return history
1158

1159
    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
        """Request the UTXO set of scripthash *sh*; validate and return it."""
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        utxos = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
        # validate the response
        assert_list_or_tuple(utxos)
        for utxo in utxos:
            # each utxo must fully describe an outpoint, its value, and its height
            for field in ('tx_pos', 'value', 'tx_hash', 'height'):
                assert_dict_contains_field(utxo, field_name=field)
            assert_non_negative_integer(utxo['tx_pos'])
            assert_non_negative_integer(utxo['value'])
            assert_non_negative_integer(utxo['height'])
            assert_hash256_str(utxo['tx_hash'])
        return utxos
1176

1177
    async def get_balance_for_scripthash(self, sh: str) -> dict:
        """Request the confirmed/unconfirmed balance of scripthash *sh*; validate and return it."""
        if not is_hash256_str(sh):
            raise Exception(f"{repr(sh)} is not a scripthash")
        balance = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
        # validate the response
        assert_dict_contains_field(balance, field_name='confirmed')
        assert_dict_contains_field(balance, field_name='unconfirmed')
        assert_non_negative_integer(balance['confirmed'])
        # note: 'unconfirmed' is only checked for integer-ness, it may be negative
        assert_integer(balance['unconfirmed'])
        return balance
1188

1189
    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
        """Look up the txid at (tx_height, tx_pos); optionally with a merkle branch."""
        if not is_non_negative_integer(tx_height):
            raise Exception(f"{repr(tx_height)} is not a block height")
        if not is_non_negative_integer(tx_pos):
            raise Exception(f"{repr(tx_pos)} should be non-negative integer")
        response = await self.session.send_request(
            'blockchain.transaction.id_from_pos',
            [tx_height, tx_pos, merkle],
        )
        # the response shape depends on whether a merkle proof was requested
        if not merkle:
            assert_hash256_str(response)
            return response
        assert_dict_contains_field(response, field_name='tx_hash')
        assert_dict_contains_field(response, field_name='merkle')
        assert_hash256_str(response['tx_hash'])
        assert_list_or_tuple(response['merkle'])
        for node_hash in response['merkle']:
            assert_hash256_str(node_hash)
        return response
1210

1211
    async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
        """Request the server's mempool fee histogram; validated to be strictly
        decreasing in feerate."""
        histogram = await self.session.send_request('mempool.get_fee_histogram')
        # validate the response
        assert_list_or_tuple(histogram)
        prev_fee = float('inf')
        for fee, vsize in histogram:
            assert_non_negative_int_or_float(fee)
            assert_non_negative_integer(vsize)
            if fee >= prev_fee:  # check monotonicity
                raise RequestCorrupted(f'fees must be in decreasing order')
            prev_fee = fee
        return histogram
1224

1225
    async def get_server_banner(self) -> str:
        """Request the server's banner text (free-form message shown to users)."""
        banner = await self.session.send_request('server.banner')
        # validate the response
        if not isinstance(banner, str):
            raise RequestCorrupted(f'{banner!r} should be a str')
        return banner
1232

1233
    async def get_donation_address(self) -> str:
        """Request the server operator's donation address; '' if absent or unrecognized."""
        addr = await self.session.send_request('server.donation_address')
        if not addr:  # ignore empty string
            return ''
        if bitcoin.is_address(addr):
            return addr
        # note: do not hard-fail -- allow server to use future-type
        #       bitcoin address we do not recognize
        self.logger.info(f"invalid donation address from server: {repr(addr)}")
        return ''
1245

1246
    async def get_relay_fee(self) -> int:
        """Returns the min relay feerate in sat/kbyte."""
        response = await self.session.send_request('blockchain.relayfee')
        # validate, then scale from BTC/kbyte to sat/kbyte, clamping at zero
        assert_non_negative_int_or_float(response)
        return max(0, int(response * bitcoin.COIN))
1255

1256
    async def get_estimatefee(self, num_blocks: int) -> int:
        """Returns a feerate estimate for getting confirmed within
        num_blocks blocks, in sat/kbyte.
        Returns -1 if the server could not provide an estimate.
        """
        if not is_non_negative_integer(num_blocks):
            raise Exception(f"{repr(num_blocks)} is not a num_blocks")
        # do request
        try:
            res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
        except aiorpcx.jsonrpc.ProtocolError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate, however apparently "electrs" does not conform
            # and sends an error instead. Convert it here:
            if "cannot estimate fee" in e.message:
                res = -1
            else:
                raise
        except aiorpcx.jsonrpc.RPCError as e:
            # The protocol spec says the server itself should already have returned -1
            # if it cannot provide an estimate. "Fulcrum" often sends:
            #   aiorpcx.jsonrpc.RPCError: (-32603, 'internal error: bitcoind request timed out')
            if e.code == JSONRPC.INTERNAL_ERROR:
                res = -1
            else:
                raise
        # check response
        if res != -1:
            assert_non_negative_int_or_float(res)
            # server replies in BTC/kbyte; convert to sat/kbyte
            res = int(res * bitcoin.COIN)
        return res
1287

1288

1289
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
4✔
1290
    chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
4✔
1291
    if chain_bad:
4✔
1292
        raise Exception('bad_header must not check!')
×
1293

1294

1295
def check_cert(host, cert):
    """Parse a PEM certificate and print *host* plus its expiry status (debug helper)."""
    try:
        der = pem.dePem(cert, 'CERTIFICATE')
        parsed = x509.X509(der)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        return

    # check_date() raising is how x509 signals an expired/not-yet-valid cert
    try:
        parsed.check_date()
        expired = False
    except Exception:
        expired = True

    msg = "host: %s\n" % host
    msg += "has_expired: %s\n" % expired
    util.print_msg(msg)
1312

1313

1314
# Used by tests
1315
def _match_hostname(name, val):
4✔
1316
    if val == name:
×
1317
        return True
×
1318

1319
    return val.startswith('*.') and name.endswith(val[1:])
×
1320

1321

1322
def test_certificates():
    """Manual helper: print the expiry status of every locally cached server cert."""
    from .simple_config import SimpleConfig
    config = SimpleConfig()
    certs_dir = os.path.join(config.path, "certs")
    for filename in os.listdir(certs_dir):
        full_path = os.path.join(certs_dir, filename)
        with open(full_path, encoding='utf-8') as f:
            check_cert(filename, f.read())
1332

1333
if __name__ == "__main__":
    # manual smoke test: validate all locally cached server certificates
    test_certificates()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc