kyuupichan / aiorpcX, build #439 (push), reported via coveralls-python
17 Mar 2024 07:56AM UTC coverage: 95.73% (-2.8%) from 98.482%
Commit by kyuupichan: "Prepare aiorpcX 0.23"

1 of 1 new or added line in 1 file covered (100.0%)
52 existing lines in 2 files now uncovered
1771 of 1850 relevant lines covered (95.73%)
0.96 hits per line

Source File: /aiorpcx/session.py (99.07% covered)
# Copyright (c) 2018-2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


__all__ = ('RPCSession', 'MessageSession', 'ExcessiveSessionCostError',
           'BatchError', 'Concurrency', 'ReplyAndDisconnect', 'SessionKind')


import asyncio
from enum import Enum
import logging
from math import ceil
import time

from aiorpcx.curio import (
    TaskGroup, TaskTimeout, timeout_after, sleep
)
from aiorpcx.framing import (
    NewlineFramer, BitcoinFramer, BadMagicError, BadChecksumError, OversizedPayloadError
)
from aiorpcx.jsonrpc import (
    Request, Batch, Notification, ProtocolError, RPCError,
    JSONRPC, JSONRPCv2, JSONRPCConnection
)


class ReplyAndDisconnect(Exception):
    '''Force a session disconnect after sending result (a Python object or an RPCError).
    '''


class ExcessiveSessionCostError(RuntimeError):
    pass


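# Concurrency is an async context manager around a semaphore whose capacity can be
# retargeted at runtime via set_target(): entering it acquires a slot and grows the
# semaphore towards any raised target, exiting it swallows releases while the target
# is lower, and a target of zero or less makes entry raise ExcessiveSessionCostError,
# which the session classes below treat as grounds for disconnection.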
class Concurrency:

    def __init__(self, target):
        self._target = int(target)
        self._semaphore = asyncio.Semaphore(self._target)
        self._sem_value = self._target

    async def _retarget_semaphore(self):
        if self._target <= 0:
            raise ExcessiveSessionCostError
        while self._sem_value < self._target:
            self._sem_value += 1
            self._semaphore.release()

    @property
    def max_concurrent(self):
        return self._target

    def set_target(self, target):
        self._target = int(target)

    async def __aenter__(self):
        await self._semaphore.acquire()
        await self._retarget_semaphore()

    async def __aexit__(self, exc_type, exc_value, traceback):
        if self._sem_value > self._target:
            self._sem_value -= 1
        else:
            self._semaphore.release()


class SessionKind(Enum):
    CLIENT = 'client'
    SERVER = 'server'


class SessionBase:
    '''Base class of networking sessions.

    There is no client / server distinction other than who initiated
    the connection.
    '''

    # Multiply this by bandwidth bytes used to get resource usage cost
    bw_cost_per_byte = 1 / 100000
    # If cost is over this requests begin to get delayed and concurrency is reduced
    cost_soft_limit = 2000
    # If cost is over this the session is closed
    cost_hard_limit = 10000
    # Resource usage is reduced by this every second
    cost_decay_per_sec = cost_hard_limit / 3600
    # Request delay ranges from 0 to this between cost_soft_limit and cost_hard_limit
    cost_sleep = 2.0
    # Base cost of an error.  Errors that took resources to discover incur additional costs
    error_base_cost = 100.0
    # Initial number of requests that can be concurrently processed
    initial_concurrent = 20
    # Send a "server busy" error if processing a request takes longer than this seconds
    processing_timeout = 30.0
    # Force-close a connection if its socket send buffer stays full this long
    max_send_delay = 20.0
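    # Worked example with the defaults above: a session whose cost reaches 6000 sits
    # halfway through the soft range (cost_hard_limit - cost_soft_limit = 8000), so
    # recalc_concurrency() sets _cost_fraction to (6000 - 2000) / 8000 = 0.5, each
    # incoming request is delayed by 0.5 * cost_sleep = 1.0 seconds, and the incoming
    # concurrency target drops to ceil((1.0 - 0.5) * initial_concurrent) = 10.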

    def __init__(self, transport, *, loop=None):
        self.transport = transport
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)
        # For logger.debug messages
        self.verbosity = 0
        self._group = TaskGroup()
        # Statistics.  The RPC object also keeps its own statistics.
        self.start_time = time.time()
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = self.start_time
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = self.start_time
        # Resource usage
        self.cost = 0.0
        self._cost_last = 0.0
        self._cost_time = self.start_time
        self._cost_fraction = 0.0
        # Concurrency control for incoming request handling
        self._incoming_concurrency = Concurrency(self.initial_concurrent)
        # By default, do not limit outgoing connections
        if self.session_kind == SessionKind.CLIENT:
            self.cost_hard_limit = 0

    async def _send_message(self, message):
        if self.verbosity >= 4:
            self.logger.debug(f'sending message {message}')
        try:
            async with timeout_after(self.max_send_delay):
                await self.transport.write(message)
        except TaskTimeout:
            await self.abort()
            raise
        self.send_size += len(message)
        self.bump_cost(len(message) * self.bw_cost_per_byte)
        self.send_count += 1
        self.last_send = time.time()
        return self.last_send

    def _bump_errors(self, exception=None):
        self.errors += 1
        self.bump_cost(self.error_base_cost + getattr(exception, 'cost', 0.0))

    @property
    def session_kind(self):
        '''Either client or server.'''
        return self.transport.kind

    async def connection_lost(self):
        pass

    def data_received(self, data):
        if self.verbosity >= 2:
            self.logger.debug(f'received data {data}')
        self.recv_size += len(data)
        self.bump_cost(len(data) * self.bw_cost_per_byte)

    def bump_cost(self, delta):
        # Delta can be positive or negative
        self.cost = max(0, self.cost + delta)
        if abs(self.cost - self._cost_last) > 100:
            self.recalc_concurrency()

    def on_disconnect_due_to_excessive_session_cost(self):
        '''Called just before disconnecting from the session, if it was consuming too
        much resources.
        '''

    def recalc_concurrency(self):
        '''Call to recalculate sleeps and concurrency for the session.  Called automatically if
        cost has drifted significantly.  Otherwise can be called at regular intervals if
        desired.
        '''
        # Refund resource usage proportionally to elapsed time; the bump passed is negative
        now = time.time()
        self.cost = max(0, self.cost - (now - self._cost_time) * self.cost_decay_per_sec)
        self._cost_time = now
        self._cost_last = self.cost

        # Setting cost_hard_limit <= 0 means to not limit concurrency
        value = self._incoming_concurrency.max_concurrent
        cost_soft_range = self.cost_hard_limit - self.cost_soft_limit
        if cost_soft_range <= 0:
            return

        cost = self.cost + self.extra_cost()
        self._cost_fraction = max(0.0, (cost - self.cost_soft_limit) / cost_soft_range)

        target = max(0, ceil((1.0 - self._cost_fraction) * self.initial_concurrent))
        if abs(target - value) > 1:
            self.logger.info(f'changing task concurrency from {value} to {target}')
        self._incoming_concurrency.set_target(target)

    async def _process_messages(self, recv_message):
        try:
            await self._process_messages_loop(recv_message)
        finally:
            # Call the hook provided for derived classes
            await self.connection_lost()

    async def process_messages(self, recv_message):
        async with self._group as group:
            await group.spawn(self._process_messages, recv_message)

            # Remove tasks
            async for task in group:
                task.result()

    def unanswered_request_count(self):
        '''The number of requests received but not yet answered.'''
        # Max with zero in case the message processing task hasn't yet spawned
        return max(0, len(self._group._pending) - 1)

    def extra_cost(self):
        '''A dynamic value added to this session's cost when deciding how much to throttle
        requests.  Can be negative.
        '''
        return 0.0

    def default_framer(self):
        '''Return a default framer.'''
        raise NotImplementedError

    def proxy(self):
        '''Returns the proxy used, or None.'''
        return self.transport.proxy()

    def remote_address(self):
        '''Returns a NetAddress or None if not connected.'''
        return self.transport.remote_address()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self.transport.is_closing()

    async def abort(self):
        '''Forcefully close the connection.'''
        await self.transport.abort()

    async def close(self, *, force_after=30):
        '''Close the connection and return when closed.'''
        await self.transport.close(force_after)


class MessageSession(SessionBase):
    '''Session class for protocols where messages are not tied to responses,
    such as the Bitcoin protocol.
    '''
    async def _process_messages_loop(self, recv_message):
        while True:
            try:
                message = await recv_message()
            except BadMagicError as e:
                magic, expected = e.args
                self.logger.error(
                    f'bad network magic: got {magic} expected {expected}, '
                    f'disconnecting'
                )
                self._bump_errors(e)
                await self._group.spawn(self.close)
                await sleep(0.001)
            except OversizedPayloadError as e:
                command, payload_len = e.args
                self.logger.error(
                    f'oversized payload of {payload_len:,d} bytes to command '
                    f'{command}, disconnecting'
                )
                self._bump_errors(e)
                await self._group.spawn(self.close)
                await sleep(0.001)
            except BadChecksumError as e:
                payload_checksum, claimed_checksum = e.args
                self.logger.warning(
                    f'checksum mismatch: actual {payload_checksum.hex()} '
                    f'vs claimed {claimed_checksum.hex()}'
                )
                self._bump_errors(e)
            else:
                self.last_recv = time.time()
                self.recv_count += 1
                await self._group.spawn(self._throttled_message(message))

    async def _throttled_message(self, message):
        '''Process a single request, respecting the concurrency limit.'''
        try:
            timeout = self.processing_timeout
            async with timeout_after(timeout):
                async with self._incoming_concurrency:
                    if self._cost_fraction:
                        await sleep(self._cost_fraction * self.cost_sleep)
                    await self.handle_message(message)
        except ProtocolError as e:
            self.logger.error(f'{e}')
            self._bump_errors(e)
        except TaskTimeout:
            self.logger.info(f'incoming request timed out after {timeout} secs')
            self._bump_errors()
        except ExcessiveSessionCostError:
            self.on_disconnect_due_to_excessive_session_cost()
            await self.close()
        except Exception:
            self.logger.exception(f'exception handling {message}')
            self._bump_errors()

    def default_framer(self):
        '''Return a bitcoin framer.'''
        return BitcoinFramer()

    async def handle_message(self, message):
        '''message is a (command, payload) pair.'''

    async def send_message(self, message):
        '''Send a message (command, payload) over the network.'''
        await self._send_message(message)


class BatchError(Exception):

    def __init__(self, request):
        super().__init__(request)
        self.request = request   # BatchRequest object


class BatchRequest:
    '''Used to build a batch request to send to the server.

    Attributes batch and results are initially None.

    Adding an invalid request or notification immediately raises a
    ProtocolError.

    On exiting the with clause, it will:

    1) create a Batch object for the requests in the order they were
       added.  If the batch is empty this raises a ProtocolError.

    2) set the "batch" attribute to be that batch

    3) send the batch request and wait for a response

    4) raise a ProtocolError if the protocol was violated by the
       server.  Currently this only happens if it gave more than one
       response to any request

    5) otherwise there is precisely one response to each Request.  Set
       the "results" attribute to the tuple of results; the responses
       are ordered to match the Requests in the batch.  Notifications
       do not get a response.

    6) if raise_errors is True and any individual response was a JSON
       RPC error response, or violated the protocol in some way, a
       BatchError exception is raised.  Otherwise the caller can be
       certain each request returned a standard result.
    '''

    def __init__(self, session, raise_errors):
        self._session = session
        self._raise_errors = raise_errors
        self._requests = []
        self.batch = None
        self.results = None

    def add_request(self, method, args=()):
        self._requests.append(Request(method, args))

    def add_notification(self, method, args=()):
        self._requests.append(Notification(method, args))

    def __len__(self):
        return len(self._requests)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            self.batch = Batch(self._requests)
            message, future = self._session.connection.send_batch(self.batch)
            self.results = await self._session._send_concurrent(message, future, len(self.batch))
            if self._raise_errors:
                if any(isinstance(item, Exception) for item in self.results):
                    raise BatchError(self)


class RPCSession(SessionBase):
    '''Base class for protocols where a message can lead to a response,
    for example JSON RPC.'''

    # Adjust outgoing request concurrency to target a round trip response time of
    # this many seconds, recalibrating every recalibrate_count requests
    target_response_time = 3.0
    recalibrate_count = 30
    # Raise a TaskTimeout if getting a response takes longer than this
    sent_request_timeout = 30.0
    log_me = False
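    # Worked example: if the last recalibrate_count sent requests averaged 6 seconds
    # against the 3 second target_response_time, _recalc_concurrency() computes a raw
    # target of current * 3 / 6 but clamps it to the floor
    # max(1, min(0.8 * current, current - 1)) and the cap
    # min(current + max(3, 0.1 * current), 250), so outgoing concurrency shrinks by
    # roughly 20% at most per recalibration and grows by at most max(3, 10% of the
    # current value), never exceeding 250.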

    def __init__(self, transport, *, loop=None, connection=None):
        super().__init__(transport, loop=loop)
        self.connection = connection or self.default_connection()
        # Concurrency control for outgoing request sending
        self._outgoing_concurrency = Concurrency(50)
        self._req_times = []

    def _recalc_concurrency(self):
        req_times = self._req_times
        avg = sum(req_times) / len(req_times)
        req_times.clear()
        current = self._outgoing_concurrency.max_concurrent
        cap = min(current + max(3, current * 0.1), 250)
        floor = max(1, min(current * 0.8, current - 1))
        if avg != 0:
            target = max(floor, min(cap, current * self.target_response_time / avg))
        else:
            target = cap
        target = int(0.5 + target)
        if target != current:
            self.logger.info(f'changing outgoing request concurrency to {target} from {current}')
            self._outgoing_concurrency.set_target(target)

    async def _process_messages_loop(self, recv_message):
        # The loop will exit when recv_message raises a ConnectionLost error; which is also
        # arranged when close() is called.
        while True:
            try:
                message = await recv_message()
            except MemoryError as e:
                self.logger.warning(f'{e!r}')
                continue

            self.last_recv = time.time()
            self.recv_count += 1
            if self.log_me:
                self.logger.info(f'processing {message}')

            try:
                requests = self.connection.receive_message(message)
            except ProtocolError as e:
                self.logger.debug(str(e))
                if e.code == JSONRPC.PARSE_ERROR:
                    e.cost = self.error_base_cost * 10
                self._bump_errors(e)
                if e.error_message:
                    await self._send_message(e.error_message)
            else:
                for request in requests:
                    await self._group.spawn(self._throttled_request(request))

    async def _throttled_request(self, request):
        '''Process a single request, respecting the concurrency limit.'''
        disconnect = False
        try:
            timeout = self.processing_timeout
            async with timeout_after(timeout):
                async with self._incoming_concurrency:
                    if self._cost_fraction:
                        await sleep(self._cost_fraction * self.cost_sleep)
                    result = await self.handle_request(request)
        except (ProtocolError, RPCError) as e:
            result = e
        except TaskTimeout:
            self.logger.info(f'incoming request {request} timed out after {timeout} secs')
            result = RPCError(JSONRPC.SERVER_BUSY, 'server busy - request timed out')
        except ReplyAndDisconnect as e:
            result = e.args[0]
            disconnect = True
        except ExcessiveSessionCostError:
            self.on_disconnect_due_to_excessive_session_cost()
            result = RPCError(JSONRPC.EXCESSIVE_RESOURCE_USAGE, 'excessive resource usage')
            disconnect = True
        except Exception:
            self.logger.exception(f'exception handling {request}')
            result = RPCError(JSONRPC.INTERNAL_ERROR, 'internal server error')

        if isinstance(request, Request):
            message = request.send_result(result)
            if message:
                await self._send_message(message)
        if isinstance(result, Exception):
            self._bump_errors(result)
        if disconnect:
            await self.close()

    async def _send_concurrent(self, message, future, request_count):
        async with self._outgoing_concurrency:
            send_time = await self._send_message(message)
            try:
                async with timeout_after(self.sent_request_timeout):
                    return await future
            finally:
                time_taken = max(0, time.time() - send_time)
                if request_count == 1:
                    self._req_times.append(time_taken)
                else:
                    self._req_times.extend([time_taken / request_count] * request_count)
                if len(self._req_times) >= self.recalibrate_count:
                    self._recalc_concurrency()

    # External API
    async def connection_lost(self):
        self.connection.cancel_pending_requests()

    def default_connection(self):
        '''Return a default connection if the user provides none.'''
        return JSONRPCConnection(JSONRPCv2)

    def default_framer(self):
        '''Return a default framer.'''
        return NewlineFramer()

    async def handle_request(self, request):
        pass

    async def send_request(self, method, args=()):
        '''Send an RPC request over the network.'''
        message, future = self.connection.send_request(Request(method, args))
        return await self._send_concurrent(message, future, 1)

    async def send_notification(self, method, args=()):
        '''Send an RPC notification over the network.'''
        message = self.connection.send_notification(Notification(method, args))
        await self._send_message(message)

    def send_batch(self, raise_errors=False):
        '''Return a BatchRequest.  Intended to be used like so:

           async with session.send_batch() as batch:
               batch.add_request("method1")
               batch.add_request("sum", (x, y))
               batch.add_notification("updated")

           for result in batch.results:
              ...

        Note that in some circumstances exceptions can be raised; see
        BatchRequest doc string.
        '''
        return BatchRequest(self, raise_errors)