• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.12
/freqtrade/rpc/external_message_consumer.py
1
"""
2
ExternalMessageConsumer module
3

4
Main purpose is to connect to external bot's message websocket to consume data
5
from it
6
"""
7
import asyncio
1✔
8
import logging
1✔
9
import socket
1✔
10
from threading import Thread
1✔
11
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypedDict, Union
1✔
12

13
import websockets
1✔
14
from pydantic import ValidationError
1✔
15

16
from freqtrade.constants import FULL_DATAFRAME_THRESHOLD
1✔
17
from freqtrade.data.dataprovider import DataProvider
1✔
18
from freqtrade.enums import RPCMessageType
1✔
19
from freqtrade.misc import remove_entry_exit_signals
1✔
20
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
1✔
21
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
1✔
22
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSAnalyzedDFRequest,
1✔
23
                                                 WSMessageSchema, WSRequestSchema,
24
                                                 WSSubscribeRequest, WSWhitelistMessage,
25
                                                 WSWhitelistRequest)
26

27

28
if TYPE_CHECKING:
1✔
29
    import websockets.connect
×
30

31

32
# Connection details for one external producer bot, as listed under
# `external_message_consumer.producers` in the configuration.
# All five keys are required.
Producer = TypedDict('Producer', {
    'name': str,        # identifier used for logging and as the channel id
    'host': str,        # producer API server host
    'port': int,        # producer API server port
    'secure': bool,     # True -> connect with wss:// instead of ws://
    'ws_token': str,    # token appended to the websocket URL for auth
})
38

39

40
# Module-scoped logger, named after this module
logger: logging.Logger = logging.getLogger(name=__name__)
41

42

43
def schema_to_dict(schema: Union[WSMessageSchema, WSRequestSchema]):
    """
    Serialize a websocket message/request schema into a plain dict.

    Fields whose value is None are left out of the result.
    """
    serialized = schema.model_dump(exclude_none=True)
    return serialized
45

46

47
class ExternalMessageConsumer:
    """
    The main controller class for consuming external messages from
    other freqtrade bots.

    A private asyncio event loop runs in a background thread and holds one
    websocket connection per configured producer. Whitelist and analyzed
    dataframe messages received from producers are handed to the DataProvider.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        dataprovider: DataProvider
    ):
        """
        :param config: Bot configuration. Settings are read from its
            ``external_message_consumer`` section (empty when missing).
        :param dataprovider: DataProvider that receives the consumed pairlists
            and dataframes.

        The consumer starts itself (see ``start``) at the end of __init__.
        """
        self._config = config
        self._dp = dataprovider

        self._running = False
        self._thread = None
        self._loop = None
        self._main_task = None
        self._sub_tasks = None

        self._emc_config = self._config.get('external_message_consumer', {})

        self.enabled = self._emc_config.get('enabled', False)
        self.producers: List[Producer] = self._emc_config.get('producers', [])

        self.wait_timeout = self._emc_config.get('wait_timeout', 30)  # in seconds
        self.ping_timeout = self._emc_config.get('ping_timeout', 10)  # in seconds
        self.sleep_time = self._emc_config.get('sleep_time', 10)  # in seconds

        # Amount of candles to request when a full dataframe is needed. Used
        # as the fallback limit when too many candles are missing.
        # NOTE(review): the initial WSAnalyzedDFRequest() below still relies on
        # the schema's own default limit.
        self.initial_candle_limit = self._emc_config.get('initial_candle_limit', 1500)

        # Message size limit, in megabytes. Default 8mb, Use bitwise operator << 20 to convert
        # as the websockets client expects bytes.
        self.message_size_limit = (self._emc_config.get('message_size_limit', 8) << 20)

        # Setting these explicitly as they probably shouldn't be changed by a user
        # Unless we somehow integrate this with the strategy to allow creating
        # callbacks for the messages
        self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF]

        # Requests sent to every producer right after connecting
        self._initial_requests: List[WSRequestSchema] = [
            WSSubscribeRequest(data=self.topics),
            WSWhitelistRequest(),
            WSAnalyzedDFRequest()
        ]

        # Specify which function to use for which RPCMessageType
        self._message_handlers: Dict[str, Callable[[str, WSMessageSchema], None]] = {
            RPCMessageType.WHITELIST: self._consume_whitelist_message,
            RPCMessageType.ANALYZED_DF: self._consume_analyzed_df_message,
        }

        # Outgoing request stream per connected producer, keyed by producer name
        self._channel_streams: Dict[str, MessageStream] = {}

        self.start()

    def start(self):
        """
        Start the main internal loop in another thread to run coroutines.

        Does nothing if a thread and loop are already set up.
        """
        if self._thread and self._loop:
            return

        logger.info("Starting ExternalMessageConsumer")

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=self._loop.run_forever)
        self._running = True
        self._thread.start()

        # Schedule the main coroutine on the background loop. The returned
        # concurrent.futures.Future is what shutdown() cancels.
        self._main_task = asyncio.run_coroutine_threadsafe(self._main(), loop=self._loop)

    def shutdown(self):
        """
        Shutdown the loop, thread, and tasks.

        Must be called from a thread other than the consumer's own loop thread,
        because it joins that thread.
        """
        if self._thread and self._loop:
            logger.info("Stopping ExternalMessageConsumer")
            self._running = False

            self._channel_streams = {}

            if self._sub_tasks:
                # The sub tasks belong to self._loop, which runs in another
                # thread. asyncio.Task.cancel() is not thread-safe, so the
                # cancellation is scheduled on the owning loop instead.
                for task in self._sub_tasks:
                    self._loop.call_soon_threadsafe(task.cancel)

            if self._main_task:
                # Cancel the main task. This is the concurrent.futures.Future from
                # run_coroutine_threadsafe, whose cancel() may be called from here.
                self._main_task.cancel()

            # _main() stops the loop in its `finally`, which lets run_forever
            # (the thread target) return.
            self._thread.join()

            # The loop is no longer running, so release its resources.
            self._loop.close()

            self._thread = None
            self._loop = None
            self._sub_tasks = None
            self._main_task = None

    async def _main(self):
        """
        The main task coroutine.

        Creates one connection task per producer and waits for all of them.
        The event loop is always stopped when this coroutine finishes, so the
        background thread can exit.
        """
        # Shared by all producer connections so messages are handled one at a time
        lock = asyncio.Lock()

        try:
            # Create a connection to each producer
            self._sub_tasks = [
                self._loop.create_task(self._handle_producer_connection(producer, lock))
                for producer in self.producers
            ]

            await asyncio.gather(*self._sub_tasks)
        except asyncio.CancelledError:
            pass
        finally:
            # Stop the loop once we are done
            self._loop.stop()

    async def _handle_producer_connection(self, producer: Producer, lock: asyncio.Lock):
        """
        Main connection loop for the consumer

        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        try:
            await self._create_connection(producer, lock)
        except asyncio.CancelledError:
            # Exit silently
            pass

    async def _create_connection(self, producer: Producer, lock: asyncio.Lock):
        """
        Actually creates and handles the websocket connection, pinging on timeout
        and handling connection errors.

        Reconnects after ``sleep_time`` seconds on connection errors. It gives
        up on this producer only if the URL is invalid.

        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        while self._running:
            try:
                host, port = producer['host'], producer['port']
                token = producer['ws_token']
                name = producer['name']
                scheme = 'wss' if producer.get('secure', False) else 'ws'
                ws_url = f"{scheme}://{host}:{port}/api/v1/message/ws?token={token}"
                # Same URL without the auth token. This is the only form that may be logged.
                log_url = f"{scheme}://{host}:{port}/api/v1/message/ws"

                # This will raise InvalidURI if the url is bad
                async with websockets.connect(
                    ws_url,
                    max_size=self.message_size_limit,
                    ping_interval=None
                ) as ws:
                    async with create_channel(
                        ws,
                        channel_id=name,
                        send_throttle=0.5
                    ) as channel:

                        # Create the message stream for this channel
                        self._channel_streams[name] = MessageStream()

                        # Run the channel tasks while connected
                        await channel.run_channel_tasks(
                            self._receive_messages(channel, producer, lock),
                            self._send_requests(channel, self._channel_streams[name])
                        )

            except (websockets.exceptions.InvalidURI, ValueError) as e:
                # The exception text may echo the full URL, so scrub the token
                # before it reaches the logs.
                reason = str(e).replace(str(token), '<redacted>') if token else str(e)
                logger.error(f"{log_url} is an invalid WebSocket URL - {reason}")
                break

            except (
                socket.gaierror,
                ConnectionRefusedError,
                websockets.exceptions.InvalidStatusCode,
                websockets.exceptions.InvalidMessage
            ) as e:
                logger.error(f"Connection Refused - {e} retrying in {self.sleep_time}s")
                await asyncio.sleep(self.sleep_time)

            except (
                websockets.exceptions.ConnectionClosedError,
                websockets.exceptions.ConnectionClosedOK
            ):
                # Just keep trying to connect again indefinitely
                await asyncio.sleep(self.sleep_time)

            except Exception as e:
                # An unforeseen error has occurred, log and continue
                logger.error("Unexpected error has occurred:")
                logger.exception(e)
                await asyncio.sleep(self.sleep_time)

    async def _send_requests(self, channel: WebSocketChannel, channel_stream: MessageStream):
        """
        Send the initial requests, then forward every request published to
        this channel's stream for as long as the stream yields.

        :param channel: The WebSocketChannel to send on
        :param channel_stream: Stream that send_producer_request publishes into
        """
        # Send the initial requests
        for init_request in self._initial_requests:
            await channel.send(schema_to_dict(init_request))

        # Now send any subsequent requests published to
        # this channel's stream
        async for request, _ in channel_stream:
            logger.debug(f"Sending request to channel - {channel} - {request}")
            await channel.send(request)

    async def _receive_messages(
        self,
        channel: WebSocketChannel,
        producer: Producer,
        lock: asyncio.Lock
    ):
        """
        Loop to handle receiving messages from a Producer

        :param channel: The WebSocketChannel object for the WebSocket
        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        while self._running:
            try:
                message = await asyncio.wait_for(
                    channel.recv(),
                    timeout=self.wait_timeout
                )

                try:
                    async with lock:
                        # Handle the message
                        self.handle_producer_message(producer, message)
                except Exception as e:
                    # A bad message must not kill the receive loop
                    logger.exception(f"Error handling producer message: {e}")

            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                # Nothing arrived within wait_timeout, or the connection closed.
                # Ping to check whether the connection is still alive.
                try:
                    # ping
                    pong = await channel.ping()
                    latency = (await asyncio.wait_for(pong, timeout=self.ping_timeout) * 1000)

                    logger.info(f"Connection to {channel} still alive, latency: {latency}ms")
                    continue

                except Exception as e:
                    # Log and re-raise so the caller can tear down and reconnect
                    logger.warning(f"Ping error {channel} - {e} - retrying in {self.sleep_time}s")
                    logger.debug(e, exc_info=e)
                    raise

    def send_producer_request(
        self,
        producer_name: str,
        request: Union[WSRequestSchema, Dict[str, Any]]
    ):
        """
        Publish a message to the producer's message stream to be
        sent by the channel task.

        If no stream exists for ``producer_name``, the request is silently dropped.

        :param producer_name: The name of the producer to publish the message to
        :param request: The request to send to the producer
        """
        if isinstance(request, WSRequestSchema):
            request = schema_to_dict(request)

        if channel_stream := self._channel_streams.get(producer_name):
            channel_stream.publish(request)

    def handle_producer_message(self, producer: Producer, message: Dict[str, Any]):
        """
        Handles external messages from a Producer.

        Validates the raw message and dispatches it to the handler registered
        for its type. Invalid, empty, or unhandled messages are logged and dropped.

        :param producer: Dictionary containing producer info
        :param message: The raw message as received from the channel
        """
        producer_name = producer.get('name', 'default')

        try:
            producer_message = WSMessageSchema.model_validate(message)
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        if not producer_message.data:
            logger.error(f"Empty message received from `{producer_name}`")
            return

        logger.debug(f"Received message of type `{producer_message.type}` from `{producer_name}`")

        message_handler = self._message_handlers.get(producer_message.type)

        if not message_handler:
            logger.info(f"Received unhandled message: `{producer_message.data}`, ignoring...")
            return

        message_handler(producer_name, producer_message)

    def _consume_whitelist_message(self, producer_name: str, message: WSMessageSchema):
        """
        Validate a whitelist message and store its pairs in the DataProvider
        under ``producer_name``.
        """
        try:
            # Validate the message
            whitelist_message = WSWhitelistMessage.model_validate(message.model_dump())
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        # Add the pairlist data to the DataProvider
        self._dp._set_producer_pairs(whitelist_message.data, producer_name=producer_name)

        logger.debug(f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`")

    def _consume_analyzed_df_message(self, producer_name: str, message: WSMessageSchema):
        """
        Validate an analyzed-dataframe message and add its dataframe to the DataProvider.

        If the DataProvider reports that it could not append the data, more
        candles are requested from the producer. The request asks for one extra
        candle of overlap, or ``initial_candle_limit`` candles when too many are
        missing.
        """
        try:
            df_message = WSAnalyzedDFMessage.model_validate(message.model_dump())
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        key = df_message.data.key
        df = df_message.data.df
        la = df_message.data.la

        pair, timeframe, candle_type = key

        if df.empty:
            logger.debug(f"Received Empty Dataframe for {key}")
            return

        # If set, remove the Entry and Exit signals from the Producer
        if self._emc_config.get('remove_entry_exit_signals', False):
            df = remove_entry_exit_signals(df)

        logger.debug(f"Received {len(df)} candle(s) for {key}")

        did_append, n_missing = self._dp._add_external_df(
            pair,
            df,
            last_analyzed=la,
            timeframe=timeframe,
            candle_type=candle_type,
            producer_name=producer_name
        )

        if not did_append:
            # We want an overlap in candles in case some data has changed
            n_missing += 1
            # Too many candles missing: request a full dataframe instead, sized
            # by the configured initial_candle_limit (default 1500).
            if n_missing >= FULL_DATAFRAME_THRESHOLD:
                n_missing = self.initial_candle_limit

            logger.warning(f"Holes in data or no existing df, requesting {n_missing} candles "
                           f"for {key} from `{producer_name}`")

            self.send_producer_request(
                producer_name,
                WSAnalyzedDFRequest(
                    data={
                        "limit": n_missing,
                        "pair": pair
                    }
                )
            )
            return

        logger.debug(
            f"Consumed message from `{producer_name}` "
            f"of type `RPCMessageType.ANALYZED_DF` for {key}")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc