• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

61.0
/freqtrade/freqai/freqai_interface.py
1
import logging
1✔
2
import threading
1✔
3
import time
1✔
4
from abc import ABC, abstractmethod
1✔
5
from collections import deque
1✔
6
from datetime import datetime, timezone
1✔
7
from pathlib import Path
1✔
8
from typing import Any, Dict, List, Literal, Optional, Tuple
1✔
9

10
import datasieve.transforms as ds
1✔
11
import numpy as np
1✔
12
import pandas as pd
1✔
13
import psutil
1✔
14
from datasieve.pipeline import Pipeline
1✔
15
from datasieve.transforms import SKLearnWrapper
1✔
16
from numpy.typing import NDArray
1✔
17
from pandas import DataFrame
1✔
18
from sklearn.preprocessing import MinMaxScaler
1✔
19

20
from freqtrade.configuration import TimeRange
1✔
21
from freqtrade.constants import DOCS_LINK, Config
1✔
22
from freqtrade.data.dataprovider import DataProvider
1✔
23
from freqtrade.enums import RunMode
1✔
24
from freqtrade.exceptions import OperationalException
1✔
25
from freqtrade.exchange import timeframe_to_seconds
1✔
26
from freqtrade.freqai.data_drawer import FreqaiDataDrawer
1✔
27
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
28
from freqtrade.freqai.utils import get_tb_logger, plot_feature_importance, record_params
1✔
29
from freqtrade.strategy.interface import IStrategy
1✔
30

31

32
# Silence pandas' SettingWithCopyWarning globally for this module's dataframe slicing.
pd.set_option("mode.chained_assignment", None)
logger = logging.getLogger(__name__)
34

35

36
class IFreqaiModel(ABC):
1✔
37
    """
38
    Class containing all tools for training and prediction in the strategy.
39
    Base*PredictionModels inherit from this class.
40

41
    Record of contribution:
42
    FreqAI was developed by a group of individuals who all contributed specific skillsets to the
43
    project.
44

45
    Conception and software development:
46
    Robert Caulk @robcaulk
47

48
    Theoretical brainstorming:
49
    Elin Törnquist @th0rntwig
50

51
    Code review, software architecture brainstorming:
52
    @xmatthias
53

54
    Beta testing and bug reporting:
55
    @bloodhunter4rc, Salah Lamkadem @ikonx, @ken11o2, @longyu, @paranoidandy, @smidelis, @smarm
56
    Juha Nykänen @suikula, Wagner Costa @wagnercosta, Johan Vlugt @Jooopieeert
57
    """
58

59
    def __init__(self, config: Config) -> None:
        """
        Set up the FreqAI model interface from the bot configuration.

        Validates the ``freqai`` section and creates the model directory for the configured
        ``identifier``. Instantiates the data drawer and loads the global metadata from disk.
        Finally records the configuration via ``record_params``.

        :param config: full bot configuration; must contain a non-empty ``freqai`` section.
        :raises OperationalException: (raised by ``assert_config``) when the ``freqai``
            section is missing or empty.
        """
        self.config = config
        self.assert_config(self.config)
        self.freqai_info: Dict[str, Any] = config["freqai"]
        # assert_config has already ensured "freqai" exists, so read from freqai_info directly
        self.data_split_parameters: Dict[str, Any] = self.freqai_info.get(
            "data_split_parameters", {})
        self.model_training_parameters: Dict[str, Any] = self.freqai_info.get(
            "model_training_parameters", {})
        self.identifier: str = self.freqai_info.get("identifier", "no_id_provided")
        self.retrain = False
        self.first = True
        self.set_full_path()
        self.save_backtest_models: bool = self.freqai_info.get("save_backtest_models", True)
        if self.save_backtest_models:
            logger.info('Backtesting module configured to save all models.')

        self.dd = FreqaiDataDrawer(Path(self.full_path), self.config)
        # set current candle to arbitrary historical date
        self.current_candle: datetime = datetime.fromtimestamp(637887600, tz=timezone.utc)
        self.dd.current_candle = self.current_candle
        self.scanning = False
        self.ft_params = self.freqai_info["feature_parameters"]
        self.corr_pairlist: List[str] = self.ft_params.get("include_corr_pairlist", [])
        self.keras: bool = self.freqai_info.get("keras", False)
        if self.keras and self.ft_params.get("DI_threshold", 0):
            self.ft_params["DI_threshold"] = 0
            logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")

        self.CONV_WIDTH = self.freqai_info.get('conv_width', 1)
        self.class_names: List[str] = []  # used in classification subclasses
        self.pair_it = 0
        self.pair_it_train = 0
        # a missing (or explicitly null) whitelist counts as zero pairs instead of len(None)
        self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist") or [])
        self.train_queue = self._set_train_queue()
        self.inference_time: float = 0
        self.train_time: float = 0
        self.begin_time: float = 0
        self.begin_time_train: float = 0
        self.base_tf_seconds = timeframe_to_seconds(self.config['timeframe'])
        self.continual_learning = self.freqai_info.get('continual_learning', False)
        self.plot_features = self.ft_params.get("plot_feature_importances", 0)
        self.corr_dataframes: Dict[str, DataFrame] = {}
        # get_corr_dataframes is controlling the caching of corr_dataframes
        # for improved performance. Careful with this boolean.
        self.get_corr_dataframes: bool = True
        self._threads: List[threading.Thread] = []
        self._stop_event = threading.Event()
        self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
        self.data_provider: Optional[DataProvider] = None
        # psutil.cpu_count() returns None when the CPU count cannot be determined
        self.max_system_threads = max(int((psutil.cpu_count() or 1) * 2 - 2), 1)
        self.can_short = True  # overridden in start() with strategy.can_short
        self.model: Any = None
        if self.ft_params.get('principal_component_analysis', False) and self.continual_learning:
            self.ft_params.update({'principal_component_analysis': False})
            logger.warning('User tried to use PCA with continual learning. Deactivating PCA.')
        self.activate_tensorboard: bool = self.freqai_info.get('activate_tensorboard', True)

        record_params(config, self.full_path)
1✔
118

119
    def __getstate__(self):
1✔
120
        """
121
        Return an empty state to be pickled in hyperopt
122
        """
123
        return ({})
×
124

125
    def assert_config(self, config: Config) -> None:
        """
        Ensure the configuration carries a non-empty ``freqai`` section.

        :param config: full bot configuration
        :raises OperationalException: when ``freqai`` is absent or empty/falsy.
        """
        freqai_section = config.get("freqai", {})
        if freqai_section:
            return
        raise OperationalException("No freqai parameters found in configuration file.")
×
129

130
    def start(self, dataframe: DataFrame, metadata: dict, strategy: IStrategy) -> DataFrame:
        """
        Entry point to the FreqaiModel for a specific pair. Trains a new model if necessary
        before making the prediction.

        In dry/live mode the live pipeline (`start_live`) is used. Otherwise the backtesting
        pipeline is used: each pair is trained for every window along the sliding window
        defined by "train_period_days" (training window) and "live_retrain_hours" (the
        backtest window immediately following the training window). The concatenated
        results for the full backtesting period are returned to the strategy.
        If "freqai_backtest_live_models" is set, historic live predictions are used instead.

        :param dataframe: full dataframe coming from the strategy - it contains the entire
            backtesting timerange plus the additional historical data needed to train the model.
        :param metadata: pair metadata coming from the strategy.
        :param strategy: strategy to train on.
        :return: the dataframe with FreqAI results attached.
        """
        self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
        self.dd.set_pair_dict_info(metadata)
        self.data_provider = strategy.dp
        self.can_short = strategy.can_short

        pair = metadata["pair"]

        if self.live:
            self.inference_timer('start')
            self.dk = FreqaiDataKitchen(self.config, self.live, pair)
            live_dk = self.start_live(dataframe, metadata, strategy, self.dk)
            dataframe = live_dk.remove_features_from_df(live_dk.return_dataframe)
        else:
            self.dk = FreqaiDataKitchen(self.config, self.live, pair)
            if self.config.get("freqai_backtest_live_models", False):
                logger.info("Backtesting using historic predictions (live models)")
                historic_dk = self.start_backtesting_from_historic_predictions(
                    dataframe, metadata, self.dk)
                dataframe = historic_dk.return_dataframe
            else:
                logger.info(f"Training {len(self.dk.training_timeranges)} timeranges")
                backtest_dk = self.start_backtesting(dataframe, metadata, self.dk, strategy)
                dataframe = backtest_dk.remove_features_from_df(backtest_dk.return_dataframe)

        self.clean_up()
        if self.live:
            self.inference_timer('stop', pair)

        return dataframe
1✔
174

175
    def clean_up(self):
1✔
176
        """
177
        Objects that should be handled by GC already between coins, but
178
        are explicitly shown here to help demonstrate the non-persistence of these
179
        objects.
180
        """
181
        self.model = None
1✔
182
        self.dk = None
1✔
183

184
    def _on_stop(self):
1✔
185
        """
186
        Callback for Subclasses to override to include logic for shutting down resources
187
        when SIGINT is sent.
188
        """
189
        return
×
190

191
    def shutdown(self):
        """
        Stop FreqAI cleanly.

        Sets the stop event and drops the data provider. Then runs the subclass `_on_stop`
        hook and joins every started thread, so that a running training iteration can finish.
        """
        logger.info("Stopping FreqAI")
        self._stop_event.set()

        self.data_provider = None
        self._on_stop()

        logger.info("Waiting on Training iteration")
        for worker in self._threads:
            worker.join()
×
205

206
    def start_scanning(self, *args, **kwargs) -> None:
1✔
207
        """
208
        Start `self._start_scanning` in a separate thread
209
        """
210
        _thread = threading.Thread(target=self._start_scanning, args=args, kwargs=kwargs)
×
211
        self._threads.append(_thread)
×
212
        _thread.start()
×
213

214
    def _start_scanning(self, strategy: IStrategy) -> None:
        """
        Continuously scan pairs for retraining on a separate thread (intracandle) to improve
        model youth. This function is agnostic to data preparation/collection/storage; it
        simply trains on whatever data is available in self.dd.

        The loop ends once `self._stop_event` is set. Waiting on the event (rather than
        sleeping) lets `shutdown` interrupt the 1-second pause instead of starting one more
        training iteration after stop was requested.

        :param strategy: IStrategy = The user defined strategy class
        """
        while not self._stop_event.wait(1):
            if not self.train_queue:
                # Every pair may have been removed from the queue (see the whitelist check
                # below); indexing an empty deque would raise IndexError and kill this thread.
                continue
            pair = self.train_queue[0]

            # ensure pair is available in dp
            if pair not in strategy.dp.current_whitelist():
                self.train_queue.popleft()
                logger.warning(f'{pair} not in current whitelist, removing from train queue.')
                continue

            (_, trained_timestamp) = self.dd.get_pair_dict_info(pair)

            dk = FreqaiDataKitchen(self.config, self.live, pair)
            (
                retrain,
                new_trained_timerange,
                data_load_timerange,
            ) = dk.check_if_new_training_required(trained_timestamp)

            if retrain:
                self.train_timer('start')
                dk.set_paths(pair, new_trained_timerange.stopts)
                try:
                    self.extract_data_and_train_model(
                        new_trained_timerange, pair, strategy, dk, data_load_timerange
                    )
                except Exception as msg:
                    logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
                                     f"Message: {msg}, skipping.")

                self.train_timer('stop', pair)

                # only rotate the queue after the first has been trained.
                self.train_queue.rotate(-1)

                self.dd.save_historic_predictions_to_disk()
                if self.freqai_info.get('write_metrics_to_disk', False):
                    self.dd.save_metric_tracker_to_disk()
×
259

260
    def start_backtesting(
        self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen, strategy: IStrategy
    ) -> FreqaiDataKitchen:
        """
        The main broad execution for backtesting. For backtesting, each pair enters and then gets
        trained for each window along the sliding window defined by "train_period_days"
        (training window) and "backtest_period_days" (backtest window, i.e. window immediately
        following the training window). FreqAI slides the window and sequentially builds
        the backtesting results before returning the concatenated results for the full
        backtesting period back to the strategy.
        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :param strategy: Strategy to train on
        :return:
            FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        """

        self.pair_it += 1
        train_it = 0
        pair = metadata["pair"]
        populate_indicators = True
        check_features = True
        # Loop enforcing the sliding window training/backtesting paradigm
        # tr_train is the training time range e.g. 1 historical month
        # tr_backtest is the backtesting time range e.g. the week directly
        # following tr_train. Both of these windows slide through the
        # entire backtest
        for tr_train, tr_backtest in zip(dk.training_timeranges, dk.backtesting_timeranges):
            (_, _) = self.dd.get_pair_dict_info(pair)
            train_it += 1
            total_trains = len(dk.backtesting_timeranges)
            self.training_timerange = tr_train
            len_backtest_df = len(dataframe.loc[(dataframe["date"] >= tr_backtest.startdt) & (
                                  dataframe["date"] < tr_backtest.stopdt), :])

            if not self.ensure_data_exists(len_backtest_df, tr_backtest, pair):
                continue

            self.log_backtesting_progress(tr_train, pair, train_it, total_trains)

            timestamp_model_id = int(tr_train.stopts)
            if dk.backtest_live_models:
                timestamp_model_id = int(tr_backtest.startts)

            dk.set_paths(pair, timestamp_model_id)

            dk.set_new_model_names(pair, timestamp_model_id)

            if dk.check_if_backtest_prediction_is_valid(len_backtest_df):
                if check_features:
                    self.dd.load_metadata(dk)
                    # use the kitchen passed in rather than self.dk for consistency with
                    # every other call in this method
                    df_fts = dk.use_strategy_to_populate_indicators(
                        strategy, prediction_dataframe=dataframe.tail(1), pair=pair
                    )
                    df_fts = dk.remove_special_chars_from_feature_names(df_fts)
                    dk.find_features(df_fts)
                    self.check_if_feature_list_matches_strategy(dk)
                    check_features = False
                append_df = dk.get_backtesting_prediction()
                dk.append_predictions(append_df)
            else:
                if populate_indicators:
                    dataframe = dk.use_strategy_to_populate_indicators(
                        strategy, prediction_dataframe=dataframe, pair=pair
                    )
                    populate_indicators = False

                dataframe_base_train = dataframe.loc[dataframe["date"] < tr_train.stopdt, :]
                dataframe_base_train = strategy.set_freqai_targets(
                    dataframe_base_train, metadata=metadata)
                dataframe_base_backtest = dataframe.loc[dataframe["date"] < tr_backtest.stopdt, :]
                dataframe_base_backtest = strategy.set_freqai_targets(
                    dataframe_base_backtest, metadata=metadata)

                tr_train = dk.buffer_timerange(tr_train)

                dataframe_train = dk.slice_dataframe(tr_train, dataframe_base_train)
                dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe_base_backtest)

                dataframe_train = dk.remove_special_chars_from_feature_names(dataframe_train)
                dataframe_backtest = dk.remove_special_chars_from_feature_names(dataframe_backtest)
                dk.get_unique_classes_from_labels(dataframe_train)

                if not self.model_exists(dk):
                    dk.find_features(dataframe_train)
                    dk.find_labels(dataframe_train)

                    try:
                        self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
                                                       self.activate_tensorboard)
                        try:
                            self.model = self.train(dataframe_train, pair, dk)
                        finally:
                            # close the logger even when training raises
                            self.tb_logger.close()
                    except Exception as msg:
                        logger.warning(
                            f"Training {pair} raised exception {msg.__class__.__name__}. "
                            f"Message: {msg}, skipping.", exc_info=True)
                        self.model = None

                    self.dd.pair_dict[pair]["trained_timestamp"] = int(
                        tr_train.stopts)
                    if self.plot_features and self.model is not None:
                        plot_feature_importance(self.model, pair, dk, self.plot_features)
                    if self.save_backtest_models and self.model is not None:
                        logger.info('Saving backtest model to disk.')
                        self.dd.save_data(self.model, pair, dk)
                    else:
                        logger.info('Saving metadata to disk.')
                        self.dd.save_metadata(dk)
                else:
                    self.model = self.dd.load_data(pair, dk)

                pred_df, do_preds = self.predict(dataframe_backtest, dk)
                append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
                dk.append_predictions(append_df)
                dk.save_backtesting_prediction(append_df)

        self.backtesting_fit_live_predictions(dk)
        dk.fill_predictions(dataframe)

        return dk
1✔
381

382
    def start_live(
        self, dataframe: DataFrame, metadata: dict, strategy: IStrategy, dk: FreqaiDataKitchen
    ) -> FreqaiDataKitchen:
        """
        The main broad execution for dry/live. Checks whether the pair needs retraining,
        (re)starts the background scanning thread, loads the current model and fills the
        strategy return arrays.

        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata
        :param strategy: IStrategy = currently employed strategy
        :param dk: FreqaiDataKitchen = data management/analysis tool associated to present pair
        :returns: FreqaiDataKitchen = the same data kitchen, with its return values populated
        :raises OperationalException: if the strategy sets process_only_new_candles = False
        """

        if not strategy.process_only_new_candles:
            raise OperationalException("You are trying to use a FreqAI strategy with "
                                       "process_only_new_candles = False. This is not supported "
                                       "by FreqAI, and it is therefore aborting.")

        pair = metadata["pair"]

        # model metadata (last trained timestamp) currently associated with this pair
        _, trained_timestamp = self.dd.get_pair_dict_info(pair)

        # append the historic data once per round
        if self.dd.historic_data:
            self.dd.update_historic_data(strategy, dk)
            logger.debug(f'Updating historic data on pair {pair}')
            self.track_current_candle()

        _, new_trained_timerange, data_load_timerange = dk.check_if_new_training_required(
            trained_timestamp
        )
        dk.set_paths(pair, new_trained_timerange.stopts)

        # load candle history into memory if it is not yet.
        if not self.dd.historic_data:
            self.dd.load_all_pair_histories(data_load_timerange, dk)

        if not self.scanning:
            self.scanning = True
            self.start_scanning(strategy)

        # load the model and associated data into the data kitchen
        self.model = self.dd.load_data(pair, dk)

        dataframe = dk.use_strategy_to_populate_indicators(
            strategy,
            prediction_dataframe=dataframe,
            pair=pair,
            do_corr_pairs=self.get_corr_dataframes,
        )

        if not self.model:
            logger.warning(f"No model ready for {pair}, returning null values to strategy.")
            self.dd.return_null_values_to_strategy(dataframe, dk)
            return dk

        if self.corr_pairlist:
            dataframe = self.cache_corr_pairlist_dfs(dataframe, dk)

        dk.find_labels(dataframe)

        self.build_strategy_return_arrays(dataframe, dk, pair, trained_timestamp)

        return dk
×
446

447
    def build_strategy_return_arrays(
        self, dataframe: DataFrame, dk: FreqaiDataKitchen, pair: str, trained_timestamp: int
    ) -> None:
        """
        Predict for `pair` and attach the accumulated return values to `dk.return_dataframe`.

        On the first call for a pair, predictions are made on the entire dataframe and used to
        initialise the return values (and the historic predictions if none exist). On later
        calls, an expired model yields null predictions (prediction == 0, do_predict == 2).
        Otherwise only the last `CONV_WIDTH` candles are predicted.

        :param dataframe: strategy dataframe with populated indicators
        :param dk: data kitchen for the current pair
        :param pair: pair being processed
        :param trained_timestamp: timestamp of the currently loaded model, used for expiry
        """
        # hold the historical predictions in memory so we are sending back
        # correct array to strategy

        if pair not in self.dd.model_return_values:
            # first predictions are made on entire historical candle set coming from strategy. This
            # allows FreqUI to show full return values.
            pred_df, do_preds = self.predict(dataframe, dk)
            if pair not in self.dd.historic_predictions:
                self.set_initial_historic_predictions(pred_df, dk, pair, dataframe)
            self.dd.set_initial_return_values(pair, pred_df, dataframe)

            dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
            return

        # use the kitchen passed in (not self.dk) so callers control which kitchen is checked
        if dk.check_if_model_expired(trained_timestamp):
            pred_df = DataFrame(np.zeros((2, len(dk.label_list))), columns=dk.label_list)
            do_preds = np.ones(2, dtype=np.int_) * 2
            dk.DI_values = np.zeros(2)
            logger.warning(
                f"Model expired for {pair}, returning null values to strategy. Strategy "
                "construction should take care to consider this event with "
                "prediction == 0 and do_predict == 2"
            )
        else:
            # remaining predictions are made only on the most recent candles for performance and
            # historical accuracy reasons.
            pred_df, do_preds = self.predict(dataframe.iloc[-self.CONV_WIDTH:], dk, first=False)

        if self.freqai_info.get('fit_live_predictions_candles', 0) and self.live:
            self.fit_live_predictions(dk, pair)
        self.dd.append_model_predictions(pair, pred_df, do_preds, dk, dataframe)
        dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
×
484

485
    def check_if_feature_list_matches_strategy(
        self, dk: FreqaiDataKitchen
    ) -> None:
        """
        Ensure user is passing the proper feature set if they are reusing an `identifier` pointing
        to a folder holding existing models.

        Compares `dk.training_features_list` (from the current strategy) against the list
        stored in the loaded model metadata. The stored list is
        ``dk.data["training_features_list_raw"]`` when present, otherwise
        ``dk.data["training_features_list"]``.

        :param dk: FreqaiDataKitchen = non-persistent data container/analyzer for
                   current coin/bot loop, with model metadata already loaded into `dk.data`
        :raises OperationalException: if the feature lists differ
        """

        if "training_features_list_raw" in dk.data:
            feature_list = dk.data["training_features_list_raw"]
        else:
            feature_list = dk.data['training_features_list']

        if dk.training_features_list != feature_list:
            raise OperationalException(
                "Trying to access pretrained model with `identifier` "
                "but found different features furnished by current strategy. "
                "Change `identifier` to train from scratch, or ensure the "
                "strategy is furnishing the same features as the pretrained "
                "model. In case of --strategy-list, please be aware that FreqAI "
                "requires all strategies to maintain identical "
                "feature_engineering_* functions"
            )
511

512
    def define_data_pipeline(self, threads=-1) -> Pipeline:
        """
        Build the feature preprocessing pipeline from ``freqai.feature_parameters``.

        It always removes constant features and scales to [-1, 1]. It then optionally adds
        PCA (plus re-scaling), SVM outlier removal, Dissimilarity Index filtering, DBSCAN
        outlier removal and Gaussian noise, in that order.

        :param threads: n_jobs forwarded to the DI and DBSCAN steps
        :return: the configured datasieve Pipeline
        """
        ft_params = self.freqai_info["feature_parameters"]
        steps = [
            ('const', ds.VarianceThreshold(threshold=0)),
            ('scaler', SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))),
        ]

        if ft_params.get("principal_component_analysis", False):
            steps += [
                ('pca', ds.PCA(n_components=0.999)),
                ('post-pca-scaler', SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))),
            ]

        if ft_params.get("use_SVM_to_remove_outliers", False):
            svm_params = ft_params.get("svm_params", {"shuffle": False, "nu": 0.01})
            steps.append(('svm', ds.SVMOutlierExtractor(**svm_params)))

        di_threshold = ft_params.get("DI_threshold", 0)
        if di_threshold:
            di_step = ds.DissimilarityIndex(di_threshold=di_threshold, n_jobs=threads)
            steps.append(('di', di_step))

        if ft_params.get("use_DBSCAN_to_remove_outliers", False):
            steps.append(('dbscan', ds.DBSCAN(n_jobs=threads)))

        noise_sigma = ft_params.get('noise_standard_deviation', 0)
        if noise_sigma:
            steps.append(('noise', ds.Noise(sigma=noise_sigma)))

        return Pipeline(steps)
1✔
541

542
    def define_label_pipeline(self, threads=-1) -> Pipeline:
        """
        Build the label pipeline: a single MinMaxScaler mapping labels to [-1, 1].

        :param threads: accepted but not used by this implementation
        :return: the label Pipeline
        """
        return Pipeline([
            ('scaler', SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))),
        ])
1✔
549

550
    def model_exists(self, dk: FreqaiDataKitchen) -> bool:
        """
        Check whether a trained model file already exists at the data kitchen's model path.

        The file checked is ``<dk.data_path>/<dk.model_filename>_model<ext>``. The extension
        is ``.joblib`` for the "joblib" model type, and ``.zip`` for "stable_baselines3",
        "sb3_contrib" and "pytorch".

        :param dk: FreqaiDataKitchen providing ``data_path`` and ``model_filename``
        :return: True if the model file exists, False otherwise.
        :raises OperationalException: if ``self.dd.model_type`` has no known file extension.
        """
        if self.dd.model_type == 'joblib':
            file_type = ".joblib"
        elif self.dd.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]:
            file_type = ".zip"
        else:
            # without this branch `file_type` would be unbound (UnboundLocalError below)
            raise OperationalException(
                f"Unknown model_type {self.dd.model_type!r}: unable to determine the "
                "model file extension."
            )

        path_to_modelfile = Path(dk.data_path) / f"{dk.model_filename}_model{file_type}"
        file_exists = path_to_modelfile.is_file()
        if file_exists:
            logger.info("Found model at %s", dk.data_path / dk.model_filename)
        else:
            logger.info("Could not find model at %s", dk.data_path / dk.model_filename)
        return file_exists
1✔
570

571
    def set_full_path(self) -> None:
1✔
572
        """
573
        Creates and sets the full path for the identifier
574
        """
575
        self.full_path = Path(
1✔
576
            self.config["user_data_dir"] / "models" / f"{self.identifier}"
577
        )
578
        self.full_path.mkdir(parents=True, exist_ok=True)
1✔
579

580
    def extract_data_and_train_model(
        self,
        new_trained_timerange: TimeRange,
        pair: str,
        strategy: IStrategy,
        dk: FreqaiDataKitchen,
        data_load_timerange: TimeRange,
    ):
        """
        Retrieve data, train a model for `pair`, and persist it.

        :param new_trained_timerange: TimeRange = the timerange to train the model on
        :param pair: str = the pair to train
        :param strategy: IStrategy = user defined strategy object
        :param dk: FreqaiDataKitchen = non-persistent data container for current coin/loop
        :param data_load_timerange: TimeRange = the amount of data to be loaded
                                    for populating indicators
                                    (larger than new_trained_timerange so that
                                    new_trained_timerange does not contain any NaNs)
        """

        corr_dataframes, base_dataframes = self.dd.get_base_and_corr_dataframes(
            data_load_timerange, pair, dk
        )

        unfiltered_dataframe = dk.use_strategy_to_populate_indicators(
            strategy, corr_dataframes, base_dataframes, pair
        )

        # captured before buffering so the stored timestamp is the unbuffered stop
        trained_timestamp = new_trained_timerange.stopts

        buffered_timerange = dk.buffer_timerange(new_trained_timerange)

        unfiltered_dataframe = dk.slice_dataframe(buffered_timerange, unfiltered_dataframe)

        # find the features indicated by strategy and store in datakitchen
        dk.find_features(unfiltered_dataframe)
        dk.find_labels(unfiltered_dataframe)

        self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
                                       self.activate_tensorboard)
        try:
            model = self.train(unfiltered_dataframe, pair, dk)
        finally:
            # close the logger even if training raises; the scanning loop catches the
            # exception and keeps running, so an unclosed logger would otherwise leak
            self.tb_logger.close()

        self.dd.pair_dict[pair]["trained_timestamp"] = trained_timestamp
        dk.set_new_model_names(pair, trained_timestamp)
        self.dd.save_data(model, pair, dk)

        if self.plot_features:
            plot_feature_importance(model, pair, dk, self.plot_features)

        self.dd.purge_old_models()
1✔
631

632
    def set_initial_historic_predictions(
        self, pred_df: DataFrame, dk: FreqaiDataKitchen, pair: str, strat_df: DataFrame
    ) -> None:
        """
        Seed the historic-predictions table for ``pair`` from a first batch of predictions.

        Only called when the datadrawer could not load an existing set of historic
        predictions. ``pred_df`` itself (not a copy) becomes the initial history and is
        extended in place with placeholder columns:
        ``<label>_mean`` / ``<label>_std`` = 0 for every non-object column,
        ``do_predict`` = 0, ``DI_values`` = 0 (only when
        ``feature_parameters.DI_threshold`` > 0), every extra per-train return value,
        and the strategy's high/low/close/date as ``high_price`` / ``low_price`` /
        ``close_price`` / ``date_pred``. FreqAI then appends real predictions to it.

        These values feed live statistics such as ``&*_mean`` / ``&*_std``, which a
        strategy can use to judge how rare a prediction is relative to the live
        history. When an identifier is reused the saved history is loaded instead and
        this method is not called.

        :param pred_df: DataFrame = the dataframe containing the predictions coming
            out of a model (stored and mutated in place)
        :param dk: FreqaiDataKitchen = object containing methods for data analysis
        :param pair: str = current pair
        :param strat_df: DataFrame = dataframe coming from strategy
        """
        history = self.dd.historic_predictions
        history[pair] = pred_df
        preds = history[pair]

        self.set_start_dry_live_date(strat_df)

        # Walk a snapshot of the original columns so the *_mean/*_std columns
        # created here are not visited themselves.
        for col in list(preds.columns):
            if preds[col].dtype == object:
                continue
            preds[f'{col}_mean'] = 0
            preds[f'{col}_std'] = 0

        preds['do_predict'] = 0

        di_threshold = self.freqai_info['feature_parameters'].get('DI_threshold', 0)
        if di_threshold > 0:
            preds['DI_values'] = 0

        extras = dk.data['extra_returns_per_train']
        for name in extras:
            preds[name] = extras[name]

        for source_col in ('high', 'low', 'close'):
            preds[f'{source_col}_price'] = strat_df[source_col]
        preds['date_pred'] = strat_df['date']
682

683
    def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None:
        """
        Fit a gaussian distribution to the most recent historic predictions of each label.

        For every label in ``dk.label_list + dk.unique_class_list``, the last
        ``fit_live_predictions_candles`` (default 100) rows of
        ``self.dd.historic_predictions[dk.pair][label]`` are fitted with
        ``scipy.stats.norm.fit``. The fitted location and scale are stored in
        ``dk.data["labels_mean"]`` and ``dk.data["labels_std"]``, which are reset on
        every call. Labels whose history column has ``object`` dtype are skipped and
        get no entry.

        :param dk: FreqaiDataKitchen = kitchen of the current pair; ``dk.pair`` selects
            the history and ``dk.data`` receives the results
        :param pair: str = current pair. Not used by this base implementation, which
            reads ``dk.pair``; kept for subclasses that override this method.
        """
        # Import the submodule explicitly: ``import scipy`` alone does not make
        # ``scipy.stats`` available as an attribute on older SciPy releases.
        from scipy import stats

        # add classes from classifier label types if used
        full_labels = dk.label_list + dk.unique_class_list

        num_candles = self.freqai_info.get("fit_live_predictions_candles", 100)
        dk.data["labels_mean"], dk.data["labels_std"] = {}, {}
        for label in full_labels:
            series = self.dd.historic_predictions[dk.pair][label]
            if series.dtype == object:
                continue
            loc, scale = stats.norm.fit(series.tail(num_candles))
            dk.data["labels_mean"][label] = loc
            dk.data["labels_std"][label] = scale

        return
702

703
    def inference_timer(self, do: Literal['start', 'stop'] = 'start', pair: str = ''):
        """
        Accumulate the time FreqAI spends inferencing during one pass over the whitelist.

        ``do='start'`` increments ``self.pair_it`` and stamps the start time.
        ``do='stop'`` adds the time elapsed since that stamp to ``self.inference_time``.
        It also reports the elapsed time to ``self.dd.update_metric_tracker`` when
        ``write_metrics_to_disk`` is enabled. Once ``self.pair_it`` equals
        ``self.total_pairs``, the pass total is logged and both counter and total are
        reset to 0. Any other value of ``do`` is a no-op.

        :param do: 'start' or 'stop'
        :param pair: pair being timed; only forwarded to the metric tracker
        """
        if do == 'start':
            self.pair_it += 1
            self.begin_time = time.time()
            return

        if do != 'stop':
            return

        elapsed = time.time() - self.begin_time
        if self.freqai_info.get('write_metrics_to_disk', False):
            self.dd.update_metric_tracker('inference_time', elapsed, pair)
        self.inference_time += elapsed

        if self.pair_it == self.total_pairs:
            logger.info(
                f'Total time spent inferencing pairlist {self.inference_time:.2f} seconds')
            self.pair_it = 0
            self.inference_time = 0
724

725
    def train_timer(self, do: Literal['start', 'stop'] = 'start', pair: str = ''):
        """
        Accumulate the time FreqAI spends training the full pairlist.

        ``do='start'`` increments ``self.pair_it_train`` and stamps the start time.
        ``do='stop'`` adds the time elapsed since that stamp to ``self.train_time``.
        It also hands the elapsed time to ``self.dd.collect_metrics`` when
        ``write_metrics_to_disk`` is enabled. When the counter equals
        ``self.total_pairs``, the total is logged and counter and total are reset to 0.
        Any other value of ``do`` does nothing.

        :param do: 'start' or 'stop'
        :param pair: pair being trained; only forwarded to the metric collector
        """
        if do == 'start':
            self.pair_it_train += 1
            self.begin_time_train = time.time()
            return

        if do != 'stop':
            return

        elapsed = time.time() - self.begin_time_train
        if self.freqai_info.get('write_metrics_to_disk', False):
            self.dd.collect_metrics(elapsed, pair)

        self.train_time += elapsed
        if self.pair_it_train == self.total_pairs:
            logger.info(
                f'Total time spent training pairlist {self.train_time:.2f} seconds')
            self.pair_it_train = 0
            self.train_time = 0
746

747
    def get_init_model(self, pair: str) -> Any:
        """
        Return the previously trained model for ``pair`` to warm-start from, if any.

        :param pair: pair whose stored model should be reused
        :return: ``self.dd.model_dictionary[pair]`` when ``self.continual_learning`` is
            truthy and a model is stored for ``pair``; otherwise None
        """
        models = self.dd.model_dictionary
        if self.continual_learning and pair in models:
            return models[pair]
        return None
754

755
    def _set_train_queue(self):
        """
        Build the training queue for the current ``exchange.pair_whitelist``.

        With no recorded pairs in ``self.dd.pair_dict``, the queue is the whitelist in
        order. Otherwise, whitelisted pairs present in ``pair_dict`` are queued by
        ascending ``trained_timestamp``. Whitelisted pairs missing from ``pair_dict`` are
        then pushed onto the front with ``appendleft``.

        :return: deque of pair names
        """
        whitelist = self.config.get("exchange", {}).get("pair_whitelist")
        if not self.dd.pair_dict:
            logger.info('Set fresh train queue from whitelist. '
                        f'Queue: {whitelist}')
            return deque(whitelist)

        by_age = sorted(self.dd.pair_dict.items(),
                        key=lambda item: item[1]['trained_timestamp'])
        queue = deque(name for name, _ in by_age if name in whitelist)

        for name in whitelist:
            if name not in queue:
                queue.appendleft(name)

        logger.info('Set existing queue from trained timestamps. '
                    f'Best approximation queue: {queue}')
        return queue
780

781
    def cache_corr_pairlist_dfs(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
        """
        Cache the corr_pairlist dfs to speed up performance for subsequent pairs during the
        current candle.

        While ``self.get_corr_dataframes`` is set, the corr-pair columns are extracted
        from ``dataframe`` into ``self.corr_dataframes``. A warning is logged if nothing
        could be extracted, and the flag stays set until an extraction yields something.
        When the flag is clear and a cache exists, the cached columns are attached to
        ``dataframe`` for ``dk.pair``.
        :param dataframe: strategy fed dataframe
        :param dk: datakitchen object for current asset
        :return: dataframe to attach/extract cached corr_pair dfs to/from.
        """
        if not self.get_corr_dataframes:
            # Cache already built for this candle: reuse it when it is non-empty.
            if self.corr_dataframes:
                dataframe = dk.attach_corr_pair_columns(
                    dataframe, self.corr_dataframes, dk.pair)
            return dataframe

        self.corr_dataframes = dk.extract_corr_pair_columns_from_populated_indicators(dataframe)
        if not self.corr_dataframes:
            logger.warning("Couldn't cache corr_pair dataframes for improved performance. "
                           "Consider ensuring that the full coin/stake, e.g. XYZ/USD, "
                           "is included in the column names when you are creating features "
                           "in `feature_engineering_*` functions.")
        # Retry extraction on the next pair until something has been cached.
        self.get_corr_dataframes = not bool(self.corr_dataframes)
        return dataframe
803

804
    def track_current_candle(self):
        """
        Reset per-candle state when the datadrawer has moved on to a newer candle.

        If ``self.dd.current_candle`` is greater than the candle last seen here, the
        corr-pair dataframe cache is flagged for refresh, the pair counter is set to 1
        and the newer candle is remembered. Otherwise nothing changes.
        """
        latest = self.dd.current_candle
        if not latest > self.current_candle:
            return
        self.get_corr_dataframes = True
        self.pair_it = 1
        self.current_candle = latest
815

816
    def ensure_data_exists(self, len_dataframe_backtest: int,
                           tr_backtest: TimeRange, pair: str) -> bool:
        """
        Check whether a backtest window has data when backtesting live models.

        Only when ``freqai_backtest_live_models`` is enabled can an empty window be
        reported. In that case an info message is logged (the likely cause is more
        than one training within the same candle period) and False is returned. In
        every other case the data is considered present.

        :param len_dataframe_backtest: the len of backtesting dataframe
        :param tr_backtest: current backtesting timerange.
        :param pair: current pair
        :return: False if live-model backtesting found no rows, True otherwise
        """
        if self.config.get("freqai_backtest_live_models", False) and len_dataframe_backtest == 0:
            # The message previously read "... {pair} from from ..." (duplicated word).
            logger.info(f"No data found for pair {pair} "
                        f"from {tr_backtest.start_fmt} to {tr_backtest.stop_fmt}. "
                        "Probably more than one training within the same candle period.")
            return False
        return True
831

832
    def log_backtesting_progress(self, tr_train: TimeRange, pair: str,
                                 train_it: int, total_trains: int):
        """
        Log which pair and sliding-window step is being trained during backtesting.

        Nothing is logged when ``freqai_backtest_live_models`` is enabled.
        :param tr_train: the training timerange
        :param pair: the current pair
        :param train_it: the train iteration for the current pair (the sliding window progress)
        :param total_trains: total trains (total number of slides for the sliding window)
        """
        if self.config.get("freqai_backtest_live_models", False):
            return

        logger.info(
            f"Training {pair}, {self.pair_it}/{self.total_pairs} pairs"
            f" from {tr_train.start_fmt} "
            f"to {tr_train.stop_fmt}, {train_it}/{total_trains} "
            "trains"
        )
849

850
    def backtesting_fit_live_predictions(self, dk: FreqaiDataKitchen):
        """
        Apply fit_live_predictions in backtesting using a rolling dummy historic_predictions.

        For every row position ``pos`` of ``dk.full_df`` from
        ``fit_live_predictions_candles`` onward, the preceding
        ``fit_live_predictions_candles`` rows are installed as the pair's historic
        predictions and ``self.fit_live_predictions`` is run. The resulting
        ``<label>_mean`` / ``<label>_std`` values and the extra per-train returns are
        then written into row ``pos``. The row-by-row loop simulates dry/live operation,
        because the logic of a user-overridden ``fit_live_predictions`` cannot be
        predicted. Nothing happens when ``fit_live_predictions_candles`` is unset or 0.

        :param dk: datakitchen object. Its ``full_df`` is read and updated in place,
            and the same object is passed to ``fit_live_predictions`` and read back
            (previously the results were read from ``self.dk`` instead).
        """
        window = self.freqai_info.get("fit_live_predictions_candles", 0)
        if not window:
            return

        logger.info("Applying fit_live_predictions in backtesting")
        label_columns = [
            col for col in dk.full_df.columns
            if col.startswith("&")
            and not col.endswith(("_mean", "_std"))
            and col not in dk.data["extra_returns_per_train"]
        ]
        # Label columns are never written below, so their dtype cannot change inside
        # the loop; object-typed labels get no gaussian fit.
        numeric_labels = [col for col in label_columns if dk.full_df[col].dtype != object]

        # Rows before `window` have too little history and were always skipped.
        for pos in range(window, len(dk.full_df)):
            self.dd.historic_predictions[dk.pair] = dk.full_df.iloc[pos - window:pos]
            self.fit_live_predictions(dk, dk.pair)

            # The history slice above is positional; translate to the index label so
            # the label-based `.at` writes hit the same row even for a non-range index.
            row = dk.full_df.index[pos]
            for label in numeric_labels:
                if "labels_mean" in dk.data:
                    dk.full_df.at[row, f"{label}_mean"] = dk.data["labels_mean"][label]
                if "labels_std" in dk.data:
                    dk.full_df.at[row, f"{label}_std"] = dk.data["labels_std"][label]

            extra_returns = dk.data["extra_returns_per_train"]
            for extra_col in extra_returns:
                dk.full_df.at[row, f"{extra_col}"] = extra_returns[extra_col]

        return
886

887
    def update_metadata(self, metadata: Dict[str, Any]):
        """
        Persist ``metadata`` as the global metadata, then keep it on the instance.

        The dict is first written through ``self.dd.save_global_metadata_to_disk``;
        only afterwards is it stored as ``self.metadata``.
        :param metadata: new global metadata dict
        """
        drawer = self.dd
        drawer.save_global_metadata_to_disk(metadata)
        self.metadata = metadata
894

895
    def set_start_dry_live_date(self, live_dataframe: DataFrame):
        """
        Record, once, the date of the last row of ``live_dataframe`` as the start of
        dry/live operation.

        If ``start_dry_live_date`` is already present in ``self.metadata`` nothing
        happens. Otherwise the last ``date`` value is stored as integer POSIX seconds
        (``int`` of ``Timestamp.timestamp()``) in ``self.metadata``, which is mutated in
        place, and the result is persisted via ``self.update_metadata``.
        :param live_dataframe: dataframe with a ``date`` column
        """
        key_name = "start_dry_live_date"
        if key_name in self.metadata:
            return

        last_date = live_dataframe["date"].values[-1]
        meta = self.metadata
        meta[key_name] = int(pd.to_datetime(last_date).timestamp())
        self.update_metadata(meta)
902

903
    def start_backtesting_from_historic_predictions(
        self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen
    ) -> FreqaiDataKitchen:
        """
        Build the backtest output for a pair from its saved historic predictions.

        Columns the strategy dataframe shares with the saved predictions are dropped
        from the strategy side. The saved predictions are then left-merged on
        ``date`` == ``date_pred``, so every strategy row is kept, and the result is
        stored in ``dk.return_dataframe``.
        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata; ``metadata["pair"]`` selects the history
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :return:
            FreqaiDataKitchen = the same ``dk``, with ``return_dataframe`` populated
        """
        pair = metadata["pair"]
        dk.return_dataframe = dataframe
        saved_preds = self.dd.historic_predictions[pair]

        overlapping = set(saved_preds.columns) & set(dataframe.columns)
        trimmed = dataframe.drop(columns=list(overlapping))
        dk.return_dataframe = trimmed

        dk.return_dataframe = pd.merge(
            trimmed, saved_preds, how='left', left_on='date', right_on="date_pred")
        return dk
922

923
    # Following methods which are overridden by user made prediction models.
924
    # See freqai/prediction_models/CatboostPredictionModel.py for an example.
925

926
    @abstractmethod
    def train(self, unfiltered_df: DataFrame, pair: str,
              dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the datahandler
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: str = pair being trained
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :param kwargs: implementation-specific options
        :return: Trained model which can be used to inference (self.predict)
        """
936

937
    @abstractmethod
    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        Fit the underlying model on prepared training data; implemented by each model.

        Most regressors expose the same method names and arguments, so e.g. an
        LGBMRegressor can replace a CatBoostRegressor while Freqai keeps handling all
        of the data management.
        :param data_dictionary: Dict = the dictionary constructed by DataHandler to hold
                                all the training and test data/labels.
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :param kwargs: implementation-specific options
        :return: implementation-defined (this abstract stub itself returns None)
        """
948

949
    @abstractmethod
    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :param kwargs: implementation-specific options
        :return: tuple of
            predictions: DataFrame of predictions
            do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
            data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index)
        """
963

964
    # deprecated functions
965
    def data_cleaning_train(self, dk: FreqaiDataKitchen, pair: str):
        """
        Deprecated: run the data pipelines over the train/test split for legacy models.

        Logs a deprecation warning. It then builds ``dk.feature_pipeline`` with
        ``self.define_data_pipeline``, fit-transforms the train features/labels/weights
        and transforms the test split. Next it builds ``dk.label_pipeline`` with
        ``self.define_label_pipeline``, fit-transforms the train labels and transforms
        the test labels. Every result is written back into ``dk.data_dictionary``.
        :param dk: FreqaiDataKitchen whose ``data_dictionary`` holds the split data
        :param pair: current pair (unused)
        """
        logger.warning(f"Your model {self.__class__.__name__} relies on the deprecated"
                       " data pipeline. Please update your model to use the new data pipeline."
                       " This can be achieved by following the migration guide at "
                       f"{DOCS_LINK}/strategy_migration/#freqai-new-data-pipeline")
        dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
        data = dk.data_dictionary

        train_x, train_y, train_w = dk.feature_pipeline.fit_transform(
            data["train_features"], data["train_labels"], data["train_weights"])
        data["train_features"] = train_x
        data["train_labels"] = train_y
        data["train_weights"] = train_w

        test_x, test_y, test_w = dk.feature_pipeline.transform(
            data["test_features"], data["test_labels"], data["test_weights"])
        data["test_features"] = test_x
        data["test_labels"] = test_y
        data["test_weights"] = test_w

        dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)

        data["train_labels"], _, _ = dk.label_pipeline.fit_transform(data["train_labels"])
        data["test_labels"], _, _ = dk.label_pipeline.transform(data["test_labels"])
        return
992

993
    def data_cleaning_predict(self, dk: FreqaiDataKitchen, pair: str):
        """
        Deprecated: transform prediction features through the already fitted pipeline.

        Logs a deprecation warning. It then replaces
        ``dk.data_dictionary["predict_features"]`` with its transform through
        ``dk.feature_pipeline`` (with ``outlier_check=True``) and stores the returned
        ``outliers`` as ``dk.do_predict``. ``dk.DI_values`` is taken from the pipeline's
        ``"di"`` step when ``feature_parameters.DI_threshold`` > 0; otherwise it is all
        zeros.
        :param dk: FreqaiDataKitchen holding the fitted ``feature_pipeline``
        :param pair: current pair (unused)
        """
        logger.warning(f"Your model {self.__class__.__name__} relies on the deprecated"
                       " data pipeline. Please update your model to use the new data pipeline."
                       " This can be achieved by following the migration guide at "
                       f"{DOCS_LINK}/strategy_migration/#freqai-new-data-pipeline")
        dd = dk.data_dictionary
        dd["predict_features"], outliers, _ = dk.feature_pipeline.transform(
            dd["predict_features"], outlier_check=True)
        # DI_threshold is a ``feature_parameters`` setting (that is where the rest of
        # this class reads it). The former top-level lookup never found it, so
        # DI_values was always zeros.
        di_threshold = self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0)
        if di_threshold > 0:
            dk.DI_values = dk.feature_pipeline["di"].di_values
        else:
            dk.DI_values = np.zeros(outliers.shape[0])
        dk.do_predict = outliers
        return
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc