freqtrade / freqtrade, build 9394558078 (push, via GitHub)

17 May 2024 04:27PM UTC coverage: 94.674% (-0.009% from 94.683%)

Commit by xmatthias: Add simple test for "fetch_my_trades" parsing quality

20334 of 21478 relevant lines covered (94.67%)
0.95 hits per line

Source File: /freqtrade/rpc/external_message_consumer.py (92.06% of lines covered)
"""
ExternalMessageConsumer module

Main purpose is to connect to an external bot's message websocket to consume
data from it.
"""

import asyncio
import logging
import socket
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypedDict, Union

import websockets
from pydantic import ValidationError

from freqtrade.constants import FULL_DATAFRAME_THRESHOLD
from freqtrade.data.dataprovider import DataProvider
from freqtrade.enums import RPCMessageType
from freqtrade.misc import remove_entry_exit_signals
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (
    WSAnalyzedDFMessage,
    WSAnalyzedDFRequest,
    WSMessageSchema,
    WSRequestSchema,
    WSSubscribeRequest,
    WSWhitelistMessage,
    WSWhitelistRequest,
)


if TYPE_CHECKING:
    import websockets.connect


class Producer(TypedDict):
    name: str
    host: str
    port: int
    secure: bool
    ws_token: str


logger = logging.getLogger(__name__)


def schema_to_dict(schema: Union[WSMessageSchema, WSRequestSchema]):
    return schema.model_dump(exclude_none=True)


class ExternalMessageConsumer:
    """
    The main controller class for consuming external messages from
    other freqtrade bots.
    """

    def __init__(self, config: Dict[str, Any], dataprovider: DataProvider):
        self._config = config
        self._dp = dataprovider

        self._running = False
        self._thread = None
        self._loop = None
        self._main_task = None
        self._sub_tasks = None

        self._emc_config = self._config.get("external_message_consumer", {})

        self.enabled = self._emc_config.get("enabled", False)
        self.producers: List[Producer] = self._emc_config.get("producers", [])

        self.wait_timeout = self._emc_config.get("wait_timeout", 30)  # in seconds
        self.ping_timeout = self._emc_config.get("ping_timeout", 10)  # in seconds
        self.sleep_time = self._emc_config.get("sleep_time", 10)  # in seconds

        # The amount of candles per dataframe on the initial request
        self.initial_candle_limit = self._emc_config.get("initial_candle_limit", 1500)

        # Message size limit, in megabytes. Defaults to 8 MB; shifted left by 20 bits
        # to convert to bytes, as the websockets client expects bytes.
        self.message_size_limit = self._emc_config.get("message_size_limit", 8) << 20

        # Setting these explicitly as they probably shouldn't be changed by a user
        # Unless we somehow integrate this with the strategy to allow creating
        # callbacks for the messages
        self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF]

        # Allow setting data for each initial request
        self._initial_requests: List[WSRequestSchema] = [
            WSSubscribeRequest(data=self.topics),
            WSWhitelistRequest(),
            WSAnalyzedDFRequest(),
        ]

        # Specify which function to use for which RPCMessageType
        self._message_handlers: Dict[str, Callable[[str, WSMessageSchema], None]] = {
            RPCMessageType.WHITELIST: self._consume_whitelist_message,
            RPCMessageType.ANALYZED_DF: self._consume_analyzed_df_message,
        }

        self._channel_streams: Dict[str, MessageStream] = {}

        self.start()

    def start(self):
        """
        Start the main internal loop in another thread to run coroutines
        """
        if self._thread and self._loop:
            return

        logger.info("Starting ExternalMessageConsumer")

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=self._loop.run_forever)
        self._running = True
        self._thread.start()

        self._main_task = asyncio.run_coroutine_threadsafe(self._main(), loop=self._loop)

    def shutdown(self):
        """
        Shutdown the loop, thread, and tasks
        """
        if self._thread and self._loop:
            logger.info("Stopping ExternalMessageConsumer")
            self._running = False

            self._channel_streams = {}

            if self._sub_tasks:
                # Cancel sub tasks
                for task in self._sub_tasks:
                    task.cancel()

            if self._main_task:
                # Cancel the main task
                self._main_task.cancel()

            self._thread.join()

            self._thread = None
            self._loop = None
            self._sub_tasks = None
            self._main_task = None

    async def _main(self):
        """
        The main task coroutine
        """
        lock = asyncio.Lock()

        try:
            # Create a connection to each producer
            self._sub_tasks = [
                self._loop.create_task(self._handle_producer_connection(producer, lock))
                for producer in self.producers
            ]

            await asyncio.gather(*self._sub_tasks)
        except asyncio.CancelledError:
            pass
        finally:
            # Stop the loop once we are done
            self._loop.stop()

    async def _handle_producer_connection(self, producer: Producer, lock: asyncio.Lock):
        """
        Main connection loop for the consumer

        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        try:
            await self._create_connection(producer, lock)
        except asyncio.CancelledError:
            # Exit silently
            pass

    async def _create_connection(self, producer: Producer, lock: asyncio.Lock):
        """
        Actually creates and handles the websocket connection, pinging on timeout
        and handling connection errors.

        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        while self._running:
            try:
                host, port = producer["host"], producer["port"]
                token = producer["ws_token"]
                name = producer["name"]
                scheme = "wss" if producer.get("secure", False) else "ws"
                ws_url = f"{scheme}://{host}:{port}/api/v1/message/ws?token={token}"

                # This will raise InvalidURI if the url is bad
                async with websockets.connect(
                    ws_url, max_size=self.message_size_limit, ping_interval=None
                ) as ws:
                    async with create_channel(ws, channel_id=name, send_throttle=0.5) as channel:
                        # Create the message stream for this channel
                        self._channel_streams[name] = MessageStream()

                        # Run the channel tasks while connected
                        await channel.run_channel_tasks(
                            self._receive_messages(channel, producer, lock),
                            self._send_requests(channel, self._channel_streams[name]),
                        )

            except (websockets.exceptions.InvalidURI, ValueError) as e:
                logger.error(f"{ws_url} is an invalid WebSocket URL - {e}")
                break

            except (
                socket.gaierror,
                ConnectionRefusedError,
                websockets.exceptions.InvalidStatusCode,
                websockets.exceptions.InvalidMessage,
            ) as e:
                logger.error(f"Connection Refused - {e} retrying in {self.sleep_time}s")
                await asyncio.sleep(self.sleep_time)
                continue

            except (
                websockets.exceptions.ConnectionClosedError,
                websockets.exceptions.ConnectionClosedOK,
            ):
                # Just keep trying to connect again indefinitely
                await asyncio.sleep(self.sleep_time)
                continue

            except Exception as e:
                # An unforeseen error has occurred, log and continue
                logger.error("Unexpected error has occurred:")
                logger.exception(e)
                await asyncio.sleep(self.sleep_time)
                continue

    async def _send_requests(self, channel: WebSocketChannel, channel_stream: MessageStream):
        # Send the initial requests
        for init_request in self._initial_requests:
            await channel.send(schema_to_dict(init_request))

        # Now send any subsequent requests published to
        # this channel's stream
        async for request, _ in channel_stream:
            logger.debug(f"Sending request to channel - {channel} - {request}")
            await channel.send(request)

    async def _receive_messages(
        self, channel: WebSocketChannel, producer: Producer, lock: asyncio.Lock
    ):
        """
        Loop to handle receiving messages from a Producer

        :param channel: The WebSocketChannel object for the WebSocket
        :param producer: Dictionary containing producer info
        :param lock: An asyncio Lock
        """
        while self._running:
            try:
                message = await asyncio.wait_for(channel.recv(), timeout=self.wait_timeout)

                try:
                    async with lock:
                        # Handle the message
                        self.handle_producer_message(producer, message)
                except Exception as e:
                    logger.exception(f"Error handling producer message: {e}")

            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                # We haven't received data yet. Check the connection and continue.
                try:
                    # ping
                    pong = await channel.ping()
                    latency = await asyncio.wait_for(pong, timeout=self.ping_timeout) * 1000

                    logger.info(f"Connection to {channel} still alive, latency: {latency}ms")
                    continue

                except Exception as e:
                    # Just eat the error and continue reconnecting
                    logger.warning(f"Ping error {channel} - {e} - retrying in {self.sleep_time}s")
                    logger.debug(e, exc_info=e)
                    raise

    def send_producer_request(
        self, producer_name: str, request: Union[WSRequestSchema, Dict[str, Any]]
    ):
        """
        Publish a message to the producer's message stream to be
        sent by the channel task.

        :param producer_name: The name of the producer to publish the message to
        :param request: The request to send to the producer
        """
        if isinstance(request, WSRequestSchema):
            request = schema_to_dict(request)

        if channel_stream := self._channel_streams.get(producer_name):
            channel_stream.publish(request)

    def handle_producer_message(self, producer: Producer, message: Dict[str, Any]):
        """
        Handles external messages from a Producer
        """
        producer_name = producer.get("name", "default")

        try:
            producer_message = WSMessageSchema.model_validate(message)
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        if not producer_message.data:
            logger.error(f"Empty message received from `{producer_name}`")
            return

        logger.debug(f"Received message of type `{producer_message.type}` from `{producer_name}`")

        message_handler = self._message_handlers.get(producer_message.type)

        if not message_handler:
            logger.info(f"Received unhandled message: `{producer_message.data}`, ignoring...")
            return

        message_handler(producer_name, producer_message)

    def _consume_whitelist_message(self, producer_name: str, message: WSMessageSchema):
        try:
            # Validate the message
            whitelist_message = WSWhitelistMessage.model_validate(message.model_dump())
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        # Add the pairlist data to the DataProvider
        self._dp._set_producer_pairs(whitelist_message.data, producer_name=producer_name)

        logger.debug(f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`")

    def _consume_analyzed_df_message(self, producer_name: str, message: WSMessageSchema):
        try:
            df_message = WSAnalyzedDFMessage.model_validate(message.model_dump())
        except ValidationError as e:
            logger.error(f"Invalid message from `{producer_name}`: {e}")
            return

        key = df_message.data.key
        df = df_message.data.df
        la = df_message.data.la

        pair, timeframe, candle_type = key

        if df.empty:
            logger.debug(f"Received Empty Dataframe for {key}")
            return

        # If set, remove the Entry and Exit signals from the Producer
        if self._emc_config.get("remove_entry_exit_signals", False):
            df = remove_entry_exit_signals(df)

        logger.debug(f"Received {len(df)} candle(s) for {key}")

        did_append, n_missing = self._dp._add_external_df(
            pair,
            df,
            last_analyzed=la,
            timeframe=timeframe,
            candle_type=candle_type,
            producer_name=producer_name,
        )

        if not did_append:
            # We want an overlap in candles in case some data has changed
            n_missing += 1
            # Set to None for all candles if we missed a full df's worth of candles
            n_missing = n_missing if n_missing < FULL_DATAFRAME_THRESHOLD else 1500

            logger.warning(
                f"Holes in data or no existing df, requesting {n_missing} candles "
                f"for {key} from `{producer_name}`"
            )

            self.send_producer_request(
                producer_name, WSAnalyzedDFRequest(data={"limit": n_missing, "pair": pair})
            )
            return

        logger.debug(
            f"Consumed message from `{producer_name}` "
            f"of type `RPCMessageType.ANALYZED_DF` for {key}"
        )
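
For context, every setting read in __init__ above comes from the "external_message_consumer" section of the bot configuration. The following is a minimal sketch of what that section might look like as a Python dict; the producer name, host, port, and token are placeholders, and the numeric values simply repeat the defaults visible in the code above.

# Hypothetical example config for ExternalMessageConsumer (not part of the covered file).
example_config = {
    "external_message_consumer": {
        "enabled": True,
        "producers": [
            {
                "name": "default",        # placeholder producer name
                "host": "127.0.0.1",      # placeholder host
                "port": 8080,             # placeholder port
                "secure": False,
                "ws_token": "replace-me", # placeholder token
            }
        ],
        "wait_timeout": 30,            # seconds
        "ping_timeout": 10,            # seconds
        "sleep_time": 10,              # seconds
        "initial_candle_limit": 1500,  # candles per dataframe on the initial request
        "message_size_limit": 8,       # megabytes; shifted << 20 to bytes internally
        "remove_entry_exit_signals": False,
    }
}
# The consumer would then be constructed with this config and a DataProvider instance:
# emc = ExternalMessageConsumer(example_config, dataprovider)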