• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 6181253459

08 Sep 2023 06:04AM UTC coverage: 94.614% (+0.06%) from 94.556%
6181253459

push

github-actions

web-flow
Merge pull request #9159 from stash86/fix-adjust

remove old codes when we only can do partial entries

2 of 2 new or added lines in 1 file covered. (100.0%)

19114 of 20202 relevant lines covered (94.61%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

58.13
/freqtrade/freqai/freqai_interface.py
1
import logging
1✔
2
import threading
1✔
3
import time
1✔
4
from abc import ABC, abstractmethod
1✔
5
from collections import deque
1✔
6
from datetime import datetime, timezone
1✔
7
from pathlib import Path
1✔
8
from typing import Any, Dict, List, Literal, Optional, Tuple
1✔
9

10
import datasieve.transforms as ds
1✔
11
import numpy as np
1✔
12
import pandas as pd
1✔
13
import psutil
1✔
14
from datasieve.pipeline import Pipeline
1✔
15
from datasieve.transforms import SKLearnWrapper
1✔
16
from numpy.typing import NDArray
1✔
17
from pandas import DataFrame
1✔
18
from sklearn.preprocessing import MinMaxScaler
1✔
19

20
from freqtrade.configuration import TimeRange
1✔
21
from freqtrade.constants import DOCS_LINK, Config
1✔
22
from freqtrade.data.dataprovider import DataProvider
1✔
23
from freqtrade.enums import RunMode
1✔
24
from freqtrade.exceptions import OperationalException
1✔
25
from freqtrade.exchange import timeframe_to_seconds
1✔
26
from freqtrade.freqai.data_drawer import FreqaiDataDrawer
1✔
27
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
28
from freqtrade.freqai.utils import get_tb_logger, plot_feature_importance, record_params
1✔
29
from freqtrade.strategy.interface import IStrategy
1✔
30

31

32
logger = logging.getLogger(__name__)
# Disable pandas' chained-assignment (SettingWithCopy) warnings for the whole process.
pd.options.mode.chained_assignment = None
34

35

36
class IFreqaiModel(ABC):
1✔
37
    """
38
    Class containing all tools for training and prediction in the strategy.
39
    Base*PredictionModels inherit from this class.
40

41
    Record of contribution:
42
    FreqAI was developed by a group of individuals who all contributed specific skillsets to the
43
    project.
44

45
    Conception and software development:
46
    Robert Caulk @robcaulk
47

48
    Theoretical brainstorming:
49
    Elin Törnquist @th0rntwig
50

51
    Code review, software architecture brainstorming:
52
    @xmatthias
53

54
    Beta testing and bug reporting:
55
    @bloodhunter4rc, Salah Lamkadem @ikonx, @ken11o2, @longyu, @paranoidandy, @smidelis, @smarm
56
    Juha Nykänen @suikula, Wagner Costa @wagnercosta, Johan Vlugt @Jooopieeert
57
    """
58

59
    def __init__(self, config: Config) -> None:
        """
        Initialise the state shared by every FreqAI prediction model.

        :param config: full bot configuration; its ``freqai`` section must be non-empty
        :raises OperationalException: (via assert_config) when the ``freqai`` section is
                                      missing or empty
        """
        self.config = config
        self.assert_config(self.config)
        self.freqai_info: Dict[str, Any] = config["freqai"]
        # assert_config guarantees the freqai section exists, so read sub-sections from it
        self.data_split_parameters: Dict[str, Any] = self.freqai_info.get(
            "data_split_parameters", {})
        self.model_training_parameters: Dict[str, Any] = self.freqai_info.get(
            "model_training_parameters", {})
        self.identifier: str = self.freqai_info.get("identifier", "no_id_provided")
        self.retrain = False
        self.first = True
        # set_full_path reads self.identifier, so it must run after it is assigned
        self.set_full_path()
        self.save_backtest_models: bool = self.freqai_info.get("save_backtest_models", True)
        if self.save_backtest_models:
            logger.info('Backtesting module configured to save all models.')

        self.dd = FreqaiDataDrawer(Path(self.full_path), self.config)
        # set current candle to arbitrary historical date
        self.current_candle: datetime = datetime.fromtimestamp(637887600, tz=timezone.utc)
        self.dd.current_candle = self.current_candle
        self.scanning = False
        self.ft_params = self.freqai_info["feature_parameters"]
        self.corr_pairlist: List[str] = self.ft_params.get("include_corr_pairlist", [])
        self.keras: bool = self.freqai_info.get("keras", False)
        if self.keras and self.ft_params.get("DI_threshold", 0):
            self.ft_params["DI_threshold"] = 0
            logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")

        self.CONV_WIDTH = self.freqai_info.get('conv_width', 1)
        self.class_names: List[str] = []  # used in classification subclasses
        self.pair_it = 0
        self.pair_it_train = 0
        # tolerate a missing / null whitelist instead of failing with len(None)
        self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist") or [])
        self.train_queue = self._set_train_queue()
        self.inference_time: float = 0
        self.train_time: float = 0
        self.begin_time: float = 0
        self.begin_time_train: float = 0
        self.base_tf_seconds = timeframe_to_seconds(self.config['timeframe'])
        self.continual_learning = self.freqai_info.get('continual_learning', False)
        self.plot_features = self.ft_params.get("plot_feature_importances", 0)
        self.corr_dataframes: Dict[str, DataFrame] = {}
        # get_corr_dataframes is controlling the caching of corr_dataframes
        # for improved performance. Careful with this boolean.
        self.get_corr_dataframes: bool = True
        self._threads: List[threading.Thread] = []
        self._stop_event = threading.Event()
        self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
        self.data_provider: Optional[DataProvider] = None
        # psutil.cpu_count() returns None when the CPU count cannot be determined
        self.max_system_threads = max(int((psutil.cpu_count() or 1) * 2 - 2), 1)
        self.can_short = True  # overridden in start() with strategy.can_short
        self.model: Any = None
        if self.ft_params.get('principal_component_analysis', False) and self.continual_learning:
            self.ft_params.update({'principal_component_analysis': False})
            logger.warning('User tried to use PCA with continual learning. Deactivating PCA.')
        self.activate_tensorboard: bool = self.freqai_info.get('activate_tensorboard', True)

        record_params(config, self.full_path)
118

119
    def __getstate__(self):
1✔
120
        """
121
        Return an empty state to be pickled in hyperopt
122
        """
123
        return ({})
×
124

125
    def assert_config(self, config: Config) -> None:
1✔
126

127
        if not config.get("freqai", {}):
1✔
128
            raise OperationalException("No freqai parameters found in configuration file.")
×
129

130
    def start(self, dataframe: DataFrame, metadata: dict, strategy: IStrategy) -> DataFrame:
        """
        Entry point to the FreqaiModel for one pair: trains a new model when necessary and
        returns the strategy dataframe with FreqAI's predictions attached.

        In dry/live mode the pair is handled by ``start_live``. In backtesting the pair is
        either trained and predicted window by window along the sliding training/backtest
        windows (``start_backtesting``), or - when ``freqai_backtest_live_models`` is set -
        rebuilt from the historic predictions of live models.

        :param dataframe: Full dataframe coming from strategy - it contains the entire
                          backtesting timerange plus the additional history needed to train
                          the model.
        :param metadata: pair metadata coming from strategy (must contain "pair").
        :param strategy: Strategy to train on
        :return: dataframe handed back to the strategy
        """
        self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
        self.dd.set_pair_dict_info(metadata)
        self.data_provider = strategy.dp
        self.can_short = strategy.can_short
        pair = metadata["pair"]

        if self.live:
            self.inference_timer('start')
            self.dk = FreqaiDataKitchen(self.config, self.live, pair)
            live_dk = self.start_live(dataframe, metadata, strategy, self.dk)
            dataframe = live_dk.remove_features_from_df(live_dk.return_dataframe)
        else:
            # Backtesting: FreqAI slides the windows and concatenates the per-window results
            # for the full backtesting period before returning them to the strategy.
            self.dk = FreqaiDataKitchen(self.config, self.live, pair)
            if self.config.get("freqai_backtest_live_models", False):
                logger.info("Backtesting using historic predictions (live models)")
                backtest_dk = self.start_backtesting_from_historic_predictions(
                    dataframe, metadata, self.dk)
                dataframe = backtest_dk.return_dataframe
            else:
                logger.info(f"Training {len(self.dk.training_timeranges)} timeranges")
                backtest_dk = self.start_backtesting(dataframe, metadata, self.dk, strategy)
                dataframe = backtest_dk.remove_features_from_df(backtest_dk.return_dataframe)

        self.clean_up()
        if self.live:
            self.inference_timer('stop', pair)

        return dataframe
175

176
    def clean_up(self):
        """
        Drop the references to the per-pair model and data kitchen.

        Garbage collection would release these between coins anyway; resetting them here
        makes explicit that they do not persist from one pair to the next.
        """
        self.model = self.dk = None
184

185
    def _on_stop(self):
1✔
186
        """
187
        Callback for Subclasses to override to include logic for shutting down resources
188
        when SIGINT is sent.
189
        """
190
        return
×
191

192
    def shutdown(self):
        """
        Stop FreqAI.

        This sets the stop event that ends the scanning loop, drops the data provider, runs
        the subclass ``_on_stop`` hook, and then joins every registered thread. A thread in
        the middle of a training iteration is therefore allowed to finish it.
        """
        logger.info("Stopping FreqAI")
        self._stop_event.set()

        self.data_provider = None
        self._on_stop()

        logger.info("Waiting on Training iteration")
        for worker in self._threads:
            worker.join()
206

207
    def start_scanning(self, *args, **kwargs) -> None:
        """
        Run ``self._start_scanning`` on a new background thread, forwarding all arguments.

        The thread is recorded in ``self._threads`` so that ``shutdown`` can join it.
        """
        scanner = threading.Thread(target=self._start_scanning, args=args, kwargs=kwargs)
        self._threads.append(scanner)
        scanner.start()
214

215
    def _start_scanning(self, strategy: IStrategy) -> None:
        """
        Function designed to constantly scan pairs for retraining on a separate thread
        (intracandle) to improve model youth. This function is agnostic to data
        preparation/collection/storage; it simply trains on whatever data is available in
        self.dd.

        The loop runs until ``self._stop_event`` is set. Each iteration inspects only the pair
        at the head of ``self.train_queue``. Pairs that left the whitelist are dropped, and
        the queue is rotated only after the head pair has been retrained.
        :param strategy: IStrategy = The user defined strategy class
        """
        while not self._stop_event.is_set():
            time.sleep(1)

            # Every pair may have been dropped from the queue; indexing an empty deque
            # would raise IndexError and terminate this thread.
            if not self.train_queue:
                continue

            pair = self.train_queue[0]

            # ensure pair is available in dp
            if pair not in strategy.dp.current_whitelist():
                self.train_queue.popleft()
                logger.warning(f'{pair} not in current whitelist, removing from train queue.')
                continue

            (_, trained_timestamp) = self.dd.get_pair_dict_info(pair)

            dk = FreqaiDataKitchen(self.config, self.live, pair)
            (
                retrain,
                new_trained_timerange,
                data_load_timerange,
            ) = dk.check_if_new_training_required(trained_timestamp)

            if retrain:
                self.train_timer('start')
                dk.set_paths(pair, new_trained_timerange.stopts)
                try:
                    self.extract_data_and_train_model(
                        new_trained_timerange, pair, strategy, dk, data_load_timerange
                    )
                except Exception as msg:
                    logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
                                     f"Message: {msg}, skipping.")

                self.train_timer('stop', pair)

                # only rotate the queue after the first has been trained.
                self.train_queue.rotate(-1)

                self.dd.save_historic_predictions_to_disk()
                if self.freqai_info.get('write_metrics_to_disk', False):
                    self.dd.save_metric_tracker_to_disk()
260

261
    def start_backtesting(
        self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen, strategy: IStrategy
    ) -> FreqaiDataKitchen:
        """
        The main broad execution for backtesting. For backtesting, each pair enters and then gets
        trained for each window along the sliding window defined by "train_period_days"
        (training window) and "backtest_period_days" (backtest window, i.e. window immediately
        following the training window). FreqAI slides the window and sequentially builds
        the backtesting results before returning the concatenated results for the full
        backtesting period back to the strategy.

        Windows for which ``dk.check_if_backtest_prediction_is_valid`` accepts a saved
        prediction reuse it, after a one-time check that the strategy's features match the
        stored model. All other windows are trained (or loaded when a model file exists),
        predicted and saved.
        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :param strategy: Strategy to train on
        :return:
            FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        """

        self.pair_it += 1
        train_it = 0
        pair = metadata["pair"]
        # indicators are populated on the full dataframe only once per pair
        populate_indicators = True
        # the feature-list consistency check only needs to run once per pair
        check_features = True
        # Loop enforcing the sliding window training/backtesting paradigm
        # tr_train is the training time range e.g. 1 historical month
        # tr_backtest is the backtesting time range e.g. the week directly
        # following tr_train. Both of these windows slide through the
        # entire backtest
        for tr_train, tr_backtest in zip(dk.training_timeranges, dk.backtesting_timeranges):
            # return value unused; NOTE(review): presumably called for its side effect of
            # initialising the pair's drawer entry - confirm against FreqaiDataDrawer
            (_, _) = self.dd.get_pair_dict_info(pair)
            train_it += 1
            total_trains = len(dk.backtesting_timeranges)
            self.training_timerange = tr_train
            len_backtest_df = len(dataframe.loc[(dataframe["date"] >= tr_backtest.startdt) & (
                                  dataframe["date"] < tr_backtest.stopdt), :])

            if not self.ensure_data_exists(len_backtest_df, tr_backtest, pair):
                continue

            self.log_backtesting_progress(tr_train, pair, train_it, total_trains)

            timestamp_model_id = int(tr_train.stopts)
            if dk.backtest_live_models:
                timestamp_model_id = int(tr_backtest.startts)

            dk.set_paths(pair, timestamp_model_id)

            dk.set_new_model_names(pair, timestamp_model_id)

            if dk.check_if_backtest_prediction_is_valid(len_backtest_df):
                if check_features:
                    self.dd.load_metadata(dk)
                    # use the kitchen passed in (not self.dk) so this method does not depend
                    # on which kitchen happens to be stored on the instance
                    df_fts = dk.use_strategy_to_populate_indicators(
                        strategy, prediction_dataframe=dataframe.tail(1), pair=pair
                    )
                    df_fts = dk.remove_special_chars_from_feature_names(df_fts)
                    dk.find_features(df_fts)
                    self.check_if_feature_list_matches_strategy(dk)
                    check_features = False
                append_df = dk.get_backtesting_prediction()
                dk.append_predictions(append_df)
            else:
                if populate_indicators:
                    dataframe = dk.use_strategy_to_populate_indicators(
                        strategy, prediction_dataframe=dataframe, pair=pair
                    )
                    populate_indicators = False

                dataframe_base_train = dataframe.loc[dataframe["date"] < tr_train.stopdt, :]
                dataframe_base_train = strategy.set_freqai_targets(
                    dataframe_base_train, metadata=metadata)
                dataframe_base_backtest = dataframe.loc[dataframe["date"] < tr_backtest.stopdt, :]
                dataframe_base_backtest = strategy.set_freqai_targets(
                    dataframe_base_backtest, metadata=metadata)

                tr_train = dk.buffer_timerange(tr_train)

                dataframe_train = dk.slice_dataframe(tr_train, dataframe_base_train)
                dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe_base_backtest)

                dataframe_train = dk.remove_special_chars_from_feature_names(dataframe_train)
                dataframe_backtest = dk.remove_special_chars_from_feature_names(dataframe_backtest)
                dk.get_unique_classes_from_labels(dataframe_train)

                if not self.model_exists(dk):
                    dk.find_features(dataframe_train)
                    dk.find_labels(dataframe_train)

                    try:
                        self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
                                                       self.activate_tensorboard)
                        try:
                            self.model = self.train(dataframe_train, pair, dk)
                        finally:
                            # close the tensorboard logger even when training raises
                            self.tb_logger.close()
                    except Exception as msg:
                        logger.warning(
                            f"Training {pair} raised exception {msg.__class__.__name__}. "
                            f"Message: {msg}, skipping.", exc_info=True)
                        self.model = None

                    self.dd.pair_dict[pair]["trained_timestamp"] = int(
                        tr_train.stopts)
                    if self.plot_features and self.model is not None:
                        plot_feature_importance(self.model, pair, dk, self.plot_features)
                    if self.save_backtest_models and self.model is not None:
                        logger.info('Saving backtest model to disk.')
                        self.dd.save_data(self.model, pair, dk)
                    else:
                        logger.info('Saving metadata to disk.')
                        self.dd.save_metadata(dk)
                else:
                    self.model = self.dd.load_data(pair, dk)

                pred_df, do_preds = self.predict(dataframe_backtest, dk)
                append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
                dk.append_predictions(append_df)
                dk.save_backtesting_prediction(append_df)

        self.backtesting_fit_live_predictions(dk)
        dk.fill_predictions(dataframe)

        return dk
382

383
    def start_live(
        self, dataframe: DataFrame, metadata: dict, strategy: IStrategy, dk: FreqaiDataKitchen
    ) -> FreqaiDataKitchen:
        """
        The main broad execution for dry/live.

        This keeps the in-memory candle history up to date and starts the background
        retraining loop on the first call. It then loads the current model for the pair and
        either returns null values (no model yet) or builds the prediction arrays for the
        strategy.
        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata (must contain "pair")
        :param strategy: IStrategy = currently employed strategy
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :returns:
            dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        """
        pair = metadata["pair"]

        # model metadata associated with the current pair
        (_, trained_timestamp) = self.dd.get_pair_dict_info(pair)

        # append the newest candles to the in-memory history once per round
        if self.dd.historic_data:
            self.dd.update_historic_data(strategy, dk)
            logger.debug(f'Updating historic data on pair {pair}')
            self.track_current_candle()

        (_, new_trained_timerange, data_load_timerange) = dk.check_if_new_training_required(
            trained_timestamp
        )
        dk.set_paths(pair, new_trained_timerange.stopts)

        # no history in memory yet: load it
        if not self.dd.historic_data:
            self.dd.load_all_pair_histories(data_load_timerange, dk)

        # the background retraining loop is only launched on the first call
        if not self.scanning:
            self.scanning = True
            self.start_scanning(strategy)

        # load the model and associated data into the data kitchen
        self.model = self.dd.load_data(pair, dk)

        dataframe = dk.use_strategy_to_populate_indicators(
            strategy, prediction_dataframe=dataframe, pair=pair,
            do_corr_pairs=self.get_corr_dataframes
        )

        if self.model:
            if self.corr_pairlist:
                dataframe = self.cache_corr_pairlist_dfs(dataframe, dk)
            dk.find_labels(dataframe)
            self.build_strategy_return_arrays(dataframe, dk, pair, trained_timestamp)
        else:
            logger.warning(
                f"No model ready for {pair}, returning null values to strategy."
            )
            self.dd.return_null_values_to_strategy(dataframe, dk)

        return dk
442

443
    def build_strategy_return_arrays(
1✔
444
        self, dataframe: DataFrame, dk: FreqaiDataKitchen, pair: str, trained_timestamp: int
445
    ) -> None:
446

447
        # hold the historical predictions in memory so we are sending back
448
        # correct array to strategy
449

450
        if pair not in self.dd.model_return_values:
×
451
            # first predictions are made on entire historical candle set coming from strategy. This
452
            # allows FreqUI to show full return values.
453
            pred_df, do_preds = self.predict(dataframe, dk)
×
454
            if pair not in self.dd.historic_predictions:
×
455
                self.set_initial_historic_predictions(pred_df, dk, pair, dataframe)
×
456
            self.dd.set_initial_return_values(pair, pred_df)
×
457

458
            dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
×
459
            return
×
460
        elif self.dk.check_if_model_expired(trained_timestamp):
×
461
            pred_df = DataFrame(np.zeros((2, len(dk.label_list))), columns=dk.label_list)
×
462
            do_preds = np.ones(2, dtype=np.int_) * 2
×
463
            dk.DI_values = np.zeros(2)
×
464
            logger.warning(
×
465
                f"Model expired for {pair}, returning null values to strategy. Strategy "
466
                "construction should take care to consider this event with "
467
                "prediction == 0 and do_predict == 2"
468
            )
469
        else:
470
            # remaining predictions are made only on the most recent candles for performance and
471
            # historical accuracy reasons.
472
            pred_df, do_preds = self.predict(dataframe.iloc[-self.CONV_WIDTH:], dk, first=False)
×
473

474
        if self.freqai_info.get('fit_live_predictions_candles', 0) and self.live:
×
475
            self.fit_live_predictions(dk, pair)
×
476
        self.dd.append_model_predictions(pair, pred_df, do_preds, dk, dataframe)
×
477
        dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
×
478

479
        return
×
480

481
    def check_if_feature_list_matches_strategy(
1✔
482
        self, dk: FreqaiDataKitchen
483
    ) -> None:
484
        """
485
        Ensure user is passing the proper feature set if they are reusing an `identifier` pointing
486
        to a folder holding existing models.
487
        :param dataframe: DataFrame = strategy provided dataframe
488
        :param dk: FreqaiDataKitchen = non-persistent data container/analyzer for
489
                   current coin/bot loop
490
        """
491

492
        if "training_features_list_raw" in dk.data:
1✔
493
            feature_list = dk.data["training_features_list_raw"]
×
494
        else:
495
            feature_list = dk.data['training_features_list']
1✔
496

497
        if dk.training_features_list != feature_list:
1✔
498
            raise OperationalException(
×
499
                "Trying to access pretrained model with `identifier` "
500
                "but found different features furnished by current strategy. "
501
                "Change `identifier` to train from scratch, or ensure the "
502
                "strategy is furnishing the same features as the pretrained "
503
                "model. In case of --strategy-list, please be aware that FreqAI "
504
                "requires all strategies to maintain identical "
505
                "feature_engineering_* functions"
506
            )
507

508
    def define_data_pipeline(self, threads=-1) -> Pipeline:
        """
        Build the feature-processing pipeline from ``freqai.feature_parameters``.

        The pipeline always starts with a variance-threshold step (threshold=0) and a MinMax
        scaler to [-1, 1]. The following optional steps are appended in this order, each
        gated by its own config key:
        - PCA (n_components=0.999) plus a second [-1, 1] scaler;
        - SVM outlier extraction;
        - dissimilarity index (DI_threshold);
        - DBSCAN;
        - additive noise (noise_standard_deviation).
        :param threads: n_jobs passed to the DI and DBSCAN steps
        :return: a new datasieve Pipeline
        """
        params = self.freqai_info["feature_parameters"]
        steps = [
            ('const', ds.VarianceThreshold(threshold=0)),
            ('scaler', SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))),
        ]

        if params.get("principal_component_analysis", False):
            steps += [
                ('pca', ds.PCA(n_components=0.999)),
                ('post-pca-scaler', SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))),
            ]

        if params.get("use_SVM_to_remove_outliers", False):
            svm_kwargs = params.get("svm_params", {"shuffle": False, "nu": 0.01})
            steps.append(('svm', ds.SVMOutlierExtractor(**svm_kwargs)))

        di_threshold = params.get("DI_threshold", 0)
        if di_threshold:
            steps.append(
                ('di', ds.DissimilarityIndex(di_threshold=di_threshold, n_jobs=threads)))

        if params.get("use_DBSCAN_to_remove_outliers", False):
            steps.append(('dbscan', ds.DBSCAN(n_jobs=threads)))

        noise_sigma = params.get('noise_standard_deviation', 0)
        if noise_sigma:
            steps.append(('noise', ds.Noise(sigma=noise_sigma)))

        return Pipeline(steps)
537

538
    def define_label_pipeline(self, threads=-1) -> Pipeline:
        """
        Build the label-processing pipeline: a single MinMax scaler to [-1, 1].
        :param threads: unused by this implementation
        :return: a new datasieve Pipeline
        """
        label_scaler = SKLearnWrapper(MinMaxScaler(feature_range=(-1, 1)))
        return Pipeline([('scaler', label_scaler)])
545

546
    def model_exists(self, dk: FreqaiDataKitchen) -> bool:
        """
        Check whether a saved model file already exists for the data kitchen's current
        ``data_path`` and ``model_filename``.

        The file extension depends on ``self.dd.model_type``:
        - ``.joblib`` for "joblib";
        - ``.zip`` for "stable_baselines3", "sb3_contrib" and "pytorch".
        :param dk: FreqaiDataKitchen whose ``data_path``/``model_filename`` are checked
        :return: whether the model file exists or not.
        :raises OperationalException: if ``self.dd.model_type`` is not one of the above
        """
        suffix_by_model_type = {
            "joblib": ".joblib",
            "stable_baselines3": ".zip",
            "sb3_contrib": ".zip",
            "pytorch": ".zip",
        }
        file_type = suffix_by_model_type.get(self.dd.model_type)
        if file_type is None:
            # an unknown type would otherwise leave the extension undefined
            raise OperationalException(
                f"Unsupported model type {self.dd.model_type!r}: cannot determine the "
                "model file extension."
            )

        path_to_modelfile = Path(dk.data_path / f"{dk.model_filename}_model{file_type}")
        file_exists = path_to_modelfile.is_file()
        if file_exists:
            logger.info("Found model at %s", dk.data_path / dk.model_filename)
        else:
            logger.info("Could not find model at %s", dk.data_path / dk.model_filename)
        return file_exists
566

567
    def set_full_path(self) -> None:
        """
        Set ``self.full_path`` to ``<user_data_dir>/models/<identifier>`` and create that
        directory, including missing parents, if it does not exist yet.
        """
        models_root = self.config["user_data_dir"] / "models"
        self.full_path = Path(models_root / f"{self.identifier}")
        self.full_path.mkdir(parents=True, exist_ok=True)
575

576
    def extract_data_and_train_model(
        self,
        new_trained_timerange: TimeRange,
        pair: str,
        strategy: IStrategy,
        dk: FreqaiDataKitchen,
        data_load_timerange: TimeRange,
    ):
        """
        Retrieve data, train a new model for ``pair`` and persist it.
        :param new_trained_timerange: TimeRange = the timerange to train the model on
        :param pair: str = pair to train the model for
        :param strategy: IStrategy = user defined strategy object
        :param dk: FreqaiDataKitchen = non-persistent data container for current coin/loop
        :param data_load_timerange: TimeRange = the amount of data to be loaded
                                    for populating indicators
                                    (larger than new_trained_timerange so that
                                    new_trained_timerange does not contain any NaNs)
        """

        corr_dataframes, base_dataframes = self.dd.get_base_and_corr_dataframes(
            data_load_timerange, pair, dk
        )

        unfiltered_dataframe = dk.use_strategy_to_populate_indicators(
            strategy, corr_dataframes, base_dataframes, pair
        )

        trained_timestamp = new_trained_timerange.stopts

        buffered_timerange = dk.buffer_timerange(new_trained_timerange)

        unfiltered_dataframe = dk.slice_dataframe(buffered_timerange, unfiltered_dataframe)

        # find the features indicated by strategy and store in datakitchen
        dk.find_features(unfiltered_dataframe)
        dk.find_labels(unfiltered_dataframe)

        self.tb_logger = get_tb_logger(self.dd.model_type, dk.data_path,
                                       self.activate_tensorboard)
        try:
            model = self.train(unfiltered_dataframe, pair, dk)
        finally:
            # release the tensorboard logger even when training raises; _start_scanning
            # catches the exception and keeps looping, so a leak would repeat every retry
            self.tb_logger.close()

        self.dd.pair_dict[pair]["trained_timestamp"] = trained_timestamp
        dk.set_new_model_names(pair, trained_timestamp)
        self.dd.save_data(model, pair, dk)

        if self.plot_features:
            plot_feature_importance(model, pair, dk, self.plot_features)

        self.dd.purge_old_models()

628
    def set_initial_historic_predictions(
        self, pred_df: DataFrame, dk: FreqaiDataKitchen, pair: str, strat_df: DataFrame
    ) -> None:
        """
        This function is called only if the datadrawer failed to load an
        existing set of historic predictions. In this case, it builds
        the structure and sets fake predictions off the first training
        data. After that, FreqAI will append new real predictions to the
        set of historic predictions.

        These values are used to generate live statistics which can be used
        in the strategy for adaptive values. E.g. &*_mean/std are quantities
        that can be computed based on live predictions from the set of historical
        predictions. Those values can be used in the user strategy to better
        assess prediction rarity, and thus wait for probabilistically favorable
        entries relative to the live historical predictions.

        If the user reuses an identifier on a subsequent instance,
        this function will not be called. In that case, "real" predictions
        will be appended to the loaded set of historic predictions.

        Every non-object column of `pred_df` receives zero-filled `<col>_mean` and
        `<col>_std` companions. `do_predict` (and `DI_values` when DI_threshold is
        enabled) is zero-filled, the per-train extra returns are copied in, and the
        strategy's close price and date are stored as `close_price` / `date_pred`.
        :param pred_df: DataFrame = predictions used to seed the historic predictions.
                        It is stored as-is (not copied) and extended in place.
        :param dk: FreqaiDataKitchen = object containing methods for data analysis;
                   supplies data['extra_returns_per_train']
        :param pair: str = current pair
        :param strat_df: DataFrame = strategy dataframe; supplies the "close" and "date"
                         columns and the start dry/live date recorded in the metadata
        """

        self.dd.historic_predictions[pair] = pred_df
        hist_preds_df = self.dd.historic_predictions[pair]

        self.set_start_dry_live_date(strat_df)

        # Snapshot the column names first: the loop body adds new *_mean/*_std columns.
        for label in list(hist_preds_df.columns):
            # non-numeric (object) columns have no meaningful mean/std
            if hist_preds_df[label].dtype == object:
                continue
            hist_preds_df[f'{label}_mean'] = 0
            hist_preds_df[f'{label}_std'] = 0

        hist_preds_df['do_predict'] = 0

        feature_parameters = self.freqai_info.get('feature_parameters', {})
        if feature_parameters.get('DI_threshold', 0) > 0:
            hist_preds_df['DI_values'] = 0

        for return_str in dk.data['extra_returns_per_train']:
            hist_preds_df[return_str] = dk.data['extra_returns_per_train'][return_str]

        hist_preds_df['close_price'] = strat_df['close']
        hist_preds_df['date_pred'] = strat_df['date']
676

677
    def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None:
        """
        Fit the labels with a gaussian distribution.

        For every label (regression labels plus classifier classes, if any), a normal
        distribution is fit to the last `fit_live_predictions_candles` (default 100)
        historic predictions of `pair`. The fitted mean/std are stored in
        dk.data["labels_mean"] / dk.data["labels_std"], keyed by label name.
        Object-dtype labels are skipped because they cannot be fit.
        :param dk: FreqaiDataKitchen = provides label_list / unique_class_list and
                   receives the fitted statistics in dk.data
        :param pair: str = pair whose historic predictions are used
        """
        # Import the submodule explicitly; `import scipy` alone does not guarantee
        # that `scipy.stats` is loaded on every SciPy version.
        from scipy import stats

        # add classes from classifier label types if used
        full_labels = dk.label_list + dk.unique_class_list

        num_candles = self.freqai_info.get("fit_live_predictions_candles", 100)
        dk.data["labels_mean"], dk.data["labels_std"] = {}, {}
        historic_predictions = self.dd.historic_predictions[pair]
        for label in full_labels:
            if historic_predictions[label].dtype == object:
                continue
            mean, std = stats.norm.fit(historic_predictions[label].tail(num_candles))
            dk.data["labels_mean"][label] = mean
            dk.data["labels_std"][label] = std

        return
696

697
    def inference_timer(self, do: Literal['start', 'stop'] = 'start', pair: str = ''):
        """
        Accumulate the wall-clock time FreqAI spends on inference during one pass
        through the whitelist. When the last pair of the pass stops, the total is
        logged and, if it exceeds a quarter of the base candle duration, the user is
        warned about degraded performance. The counters are then reset.
        :param do: 'start' begins timing the current pair, 'stop' ends it; any other
                   value is a no-op.
        :param pair: pair name, used when inference metrics are written to disk.
        """
        if do == 'start':
            self.pair_it += 1
            self.begin_time = time.time()
            return

        if do != 'stop':
            return

        elapsed = time.time() - self.begin_time
        if self.freqai_info.get('write_metrics_to_disk', False):
            self.dd.update_metric_tracker('inference_time', elapsed, pair)
        self.inference_time += elapsed

        if self.pair_it != self.total_pairs:
            return

        logger.info(
            f'Total time spent inferencing pairlist {self.inference_time:.2f} seconds')
        if self.inference_time > 0.25 * self.base_tf_seconds:
            logger.warning("Inference took over 25% of the candle time. Reduce pairlist to"
                           " avoid blinding open trades and degrading performance.")
        # full pass done: reset the accumulators for the next pass
        self.pair_it = 0
        self.inference_time = 0
721

722
    def train_timer(self, do: Literal['start', 'stop'] = 'start', pair: str = ''):
        """
        Accumulate the wall-clock time spent training the full pairlist in FreqAI.
        When the last pair of the pass stops, the total is logged and the counters
        are reset.
        :param do: 'start' begins timing the current pair, 'stop' ends it; any other
                   value is a no-op.
        :param pair: pair name, passed to the metric collector when metrics are
                     written to disk.
        """
        if do == 'start':
            self.pair_it_train += 1
            self.begin_time_train = time.time()
            return

        if do != 'stop':
            return

        elapsed = time.time() - self.begin_time_train
        if self.freqai_info.get('write_metrics_to_disk', False):
            self.dd.collect_metrics(elapsed, pair)

        self.train_time += elapsed
        if self.pair_it_train != self.total_pairs:
            return

        logger.info(
            f'Total time spent training pairlist {self.train_time:.2f} seconds')
        self.pair_it_train = 0
        self.train_time = 0
743

744
    def get_init_model(self, pair: str) -> Any:
        """
        Return the model to continue training from for `pair`.
        :param pair: pair whose stored model is requested
        :return: the model held in the data drawer's model_dictionary when one exists
                 for `pair` and continual learning is enabled; otherwise None.
        """
        if pair in self.dd.model_dictionary and self.continual_learning:
            return self.dd.model_dictionary[pair]
        return None
751

752
    def _set_train_queue(self):
        """
        Sets train queue from existing train timestamps if they exist
        otherwise it sets the train queue based on the provided whitelist.

        With existing timestamps, whitelisted pairs that were trained before are
        queued oldest-trained first. Whitelisted pairs with no trained model are
        pushed to the front of the queue so they are trained first.
        :return: deque of pair names
        """
        # A missing whitelist is treated as empty instead of crashing on deque(None).
        current_pairlist = self.config.get("exchange", {}).get("pair_whitelist") or []
        if not self.dd.pair_dict:
            logger.info('Set fresh train queue from whitelist. '
                        f'Queue: {current_pairlist}')
            return deque(current_pairlist)

        best_queue = deque()

        # Sets give O(1) membership tests; the queue itself keeps the ordering.
        whitelist = set(current_pairlist)
        pair_dict_sorted = sorted(self.dd.pair_dict.items(),
                                  key=lambda k: k[1]['trained_timestamp'])
        for pair, _ in pair_dict_sorted:
            if pair in whitelist:
                best_queue.append(pair)

        queued = set(best_queue)
        for pair in current_pairlist:
            if pair not in queued:
                best_queue.appendleft(pair)
                queued.add(pair)

        logger.info('Set existing queue from trained timestamps. '
                    f'Best approximation queue: {best_queue}')
        return best_queue
777

778
    def cache_corr_pairlist_dfs(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
        """
        Cache the corr_pairlist columns so later pairs in the same candle can reuse them.

        When a refresh is requested (self.get_corr_dataframes), the corr-pair columns
        are extracted from `dataframe` into self.corr_dataframes and `dataframe` is
        returned unchanged. Otherwise, the cached columns (if any) are attached to
        `dataframe` for the current pair.
        :param dataframe: strategy fed dataframe
        :param dk: datakitchen object for current asset
        :return: dataframe to attach/extract cached corr_pair dfs to/from.
        """
        if not self.get_corr_dataframes:
            if self.corr_dataframes:
                dataframe = dk.attach_corr_pair_columns(
                    dataframe, self.corr_dataframes, dk.pair)
            return dataframe

        self.corr_dataframes = dk.extract_corr_pair_columns_from_populated_indicators(dataframe)
        has_cache = bool(self.corr_dataframes)
        if not has_cache:
            logger.warning("Couldn't cache corr_pair dataframes for improved performance. "
                           "Consider ensuring that the full coin/stake, e.g. XYZ/USD, "
                           "is included in the column names when you are creating features "
                           "in `feature_engineering_*` functions.")
        # Keep requesting extraction on following pairs until something gets cached.
        self.get_corr_dataframes = not has_cache
        return dataframe
800

801
    def track_current_candle(self):
        """
        Detect whether the data drawer has seen a newer candle than FreqAI.

        When it has, the cached corr dataframes are marked for refresh, the pair
        counter is reset to 1 and FreqAI's current candle is advanced.
        """
        newest_candle = self.dd.current_candle
        if not newest_candle > self.current_candle:
            return
        self.get_corr_dataframes = True
        self.pair_it = 1
        self.current_candle = newest_candle
812

813
    def ensure_data_exists(self, len_dataframe_backtest: int,
                           tr_backtest: TimeRange, pair: str) -> bool:
        """
        Check whether the current backtesting window contains any data.

        An empty window is only treated as missing data when backtesting live models
        (config "freqai_backtest_live_models"). In that case, an info message is logged.
        :param len_dataframe_backtest: the len of backtesting dataframe
        :param tr_backtest: current backtesting timerange.
        :param pair: current pair
        :return: False if the window is empty while backtesting live models, else True
        """
        backtest_live_models = self.config.get("freqai_backtest_live_models", False)
        if backtest_live_models and len_dataframe_backtest == 0:
            logger.info(f"No data found for pair {pair} from "
                        f"{tr_backtest.start_fmt} to {tr_backtest.stop_fmt}. "
                        "Probably more than one training within the same candle period.")
            return False
        return True
828

829
    def log_backtesting_progress(self, tr_train: TimeRange, pair: str,
                                 train_it: int, total_trains: int):
        """
        Report backtesting training progress (pairs done and sliding-window trains
        done) so the user can see how much remains. Nothing is logged when
        backtesting live models.
        :param tr_train: the training timerange
        :param pair: the current pair
        :param train_it: the train iteration for the current pair (the sliding window progress)
        :param total_trains: total trains (total number of slides for the sliding window)
        """
        if self.config.get("freqai_backtest_live_models", False):
            return

        logger.info(
            f"Training {pair}, {self.pair_it}/{self.total_pairs} pairs"
            f" from {tr_train.start_fmt} "
            f"to {tr_train.stop_fmt}, {train_it}/{total_trains} "
            "trains"
        )
846

847
    def backtesting_fit_live_predictions(self, dk: FreqaiDataKitchen):
        """
        Apply fit_live_predictions function in backtesting with a dummy historic_predictions.
        The loop is required to simulate dry/live operation, as it is not possible to predict
        the type of logic implemented by the user.

        For every row of dk.full_df at position >= `fit_live_predictions_candles`:
        - the preceding `fit_live_predictions_candles` rows are installed as the pair's
          historic predictions;
        - fit_live_predictions is run;
        - the resulting per-label mean/std and any extra_returns_per_train values are
          written into that row.
        Does nothing when `fit_live_predictions_candles` is 0 or unset.
        :param dk: datakitchen object for the pair being backtested; dk.full_df is
                   updated in place
        """
        fit_live_predictions_candles = self.freqai_info.get("fit_live_predictions_candles", 0)
        if not fit_live_predictions_candles:
            return

        logger.info("Applying fit_live_predictions in backtesting")
        extra_returns = dk.data["extra_returns_per_train"]
        # Prediction columns to annotate: "&"-prefixed, not a derived *_mean/*_std
        # column, not an extra return, and numeric (object columns cannot be fit).
        # A label's dtype does not change inside the loop, so it is checked once here.
        label_columns = [
            col for col in dk.full_df.columns
            if col.startswith("&")
            and not col.endswith(("_mean", "_std"))
            and col not in extra_returns
            and dk.full_df[col].dtype != object
        ]

        for index in range(len(dk.full_df)):
            if index < fit_live_predictions_candles:
                continue
            # `index` is positional: translate it to the row label so `.at` writes the
            # same row that `iloc` slices up to, even when full_df has no RangeIndex.
            row = dk.full_df.index[index]
            self.dd.historic_predictions[dk.pair] = (
                dk.full_df.iloc[index - fit_live_predictions_candles:index])
            self.fit_live_predictions(dk, dk.pair)
            for label in label_columns:
                if "labels_mean" in dk.data:
                    dk.full_df.at[row, f"{label}_mean"] = dk.data["labels_mean"][label]
                if "labels_std" in dk.data:
                    dk.full_df.at[row, f"{label}_std"] = dk.data["labels_std"][label]

            for extra_col in dk.data["extra_returns_per_train"]:
                dk.full_df.at[row, f"{extra_col}"] = (
                    dk.data["extra_returns_per_train"][extra_col])

        return
883

884
    def update_metadata(self, metadata: Dict[str, Any]):
        """
        Replace the global metadata. `metadata` is first written to disk through the
        data drawer and then kept as this model's in-memory metadata.
        :param metadata: new global metadata dict
        """
        drawer = self.dd
        drawer.save_global_metadata_to_disk(metadata)
        self.metadata = metadata
891

892
    def set_start_dry_live_date(self, live_dataframe: DataFrame):
        """
        Record, once, when dry/live operation started.

        If the global metadata has no "start_dry_live_date" entry yet, the date of the
        last row of `live_dataframe` is stored there as integer POSIX seconds, and the
        metadata is persisted via update_metadata. Otherwise, nothing happens.
        :param live_dataframe: dataframe with a "date" column; its last row is used
        """
        key_name = "start_dry_live_date"
        if key_name in self.metadata:
            return
        last_date = live_dataframe["date"].values[-1]
        updated = self.metadata
        updated[key_name] = int(pd.to_datetime(last_date).timestamp())
        self.update_metadata(updated)
899

900
    def start_backtesting_from_historic_predictions(
        self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen
    ) -> FreqaiDataKitchen:
        """
        Build dk.return_dataframe by left-joining the pair's stored historic predictions
        onto the strategy dataframe. The strategy's "date" column is matched against the
        predictions' "date_pred" column. Strategy columns that also exist in the stored
        predictions are dropped first, so those values come from the stored predictions.
        :param dataframe: DataFrame = strategy passed dataframe
        :param metadata: Dict = pair metadata; "pair" selects the historic predictions
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :return:
            FreqaiDataKitchen = the same dk, with return_dataframe populated
        """
        pair = metadata["pair"]
        dk.return_dataframe = dataframe
        stored_predictions = self.dd.historic_predictions[pair]
        shared_columns = set(stored_predictions.columns) & set(dk.return_dataframe.columns)
        dk.return_dataframe = dk.return_dataframe.drop(columns=list(shared_columns))
        dk.return_dataframe = dk.return_dataframe.merge(
            stored_predictions, how='left', left_on='date', right_on="date_pred")
        return dk
919

920
    # The following methods are overridden by user-made prediction models.
921
    # See freqai/prediction_models/CatboostPredictionModel.py for an example.
922

923
    @abstractmethod
    def train(self, unfiltered_df: DataFrame, pair: str,
              dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the
        data kitchen (dk) for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: the pair being trained
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :return: Trained model which can be used to inference (self.predict)
        """
933

934
    @abstractmethod
    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        Fit a model to the prepared training data.

        Most regressors share the same method names and arguments, so e.g. an
        LGBMRegressor can be dropped in place of a CatBoostRegressor while FreqAI
        handles all the surrounding data management.
        :param data_dictionary: Dict = the dictionary holding all the training and
                                test data/labels
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to
                   present pair only
        :return: implementation-defined, typically the fitted model
        """
        return None
945

946
    @abstractmethod
    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
        :return: a tuple of
            predictions: DataFrame of predictions
            do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
            data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index)
        """
960

961
    # deprecated functions
962
    def data_cleaning_train(self, dk: FreqaiDataKitchen, pair: str):
        """
        Deprecated entry point kept for models written against the old data pipeline.

        It logs a migration warning and then:
        - builds the model's feature pipeline and fits it on the train split of
          dk.data_dictionary, transforming that split;
        - transforms the test split with the same pipeline;
        - builds and fits the label pipeline on the train labels and applies it to
          the test labels.
        The transformed splits are written back into dk.data_dictionary.
        :param dk: datakitchen holding data_dictionary; receives feature_pipeline and
                   label_pipeline
        :param pair: unused; kept for backwards compatibility
        """
        logger.warning(f"Your model {self.__class__.__name__} relies on the deprecated"
                       " data pipeline. Please update your model to use the new data pipeline."
                       " This can be achieved by following the migration guide at "
                       f"{DOCS_LINK}/strategy_migration/#freqai-new-data-pipeline")
        feature_pipe = self.define_data_pipeline(threads=dk.thread_count)
        dk.feature_pipeline = feature_pipe
        data = dk.data_dictionary

        train_result = feature_pipe.fit_transform(
            data["train_features"], data["train_labels"], data["train_weights"])
        data["train_features"], data["train_labels"], data["train_weights"] = train_result

        test_result = feature_pipe.transform(
            data["test_features"], data["test_labels"], data["test_weights"])
        data["test_features"], data["test_labels"], data["test_weights"] = test_result

        label_pipe = self.define_label_pipeline(threads=dk.thread_count)
        dk.label_pipeline = label_pipe

        data["train_labels"], _, _ = label_pipe.fit_transform(data["train_labels"])
        data["test_labels"], _, _ = label_pipe.transform(data["test_labels"])
        return
989

990
    def data_cleaning_predict(self, dk: FreqaiDataKitchen, pair: str):
        """
        Deprecated entry point kept for models written against the old data pipeline.

        It logs a migration warning and then transforms
        dk.data_dictionary["predict_features"] with the already-fitted feature pipeline
        (with outlier checking). The results are stored as follows:
        - the outlier mask goes into dk.do_predict;
        - the dissimilarity-index values go into dk.DI_values (zeros when DI is disabled).
        :param dk: datakitchen holding data_dictionary and a fitted feature_pipeline
        :param pair: unused; kept for backwards compatibility
        """
        logger.warning(f"Your model {self.__class__.__name__} relies on the deprecated"
                       " data pipeline. Please update your model to use the new data pipeline."
                       " This can be achieved by following the migration guide at "
                       f"{DOCS_LINK}/strategy_migration/#freqai-new-data-pipeline")
        dd = dk.data_dictionary
        dd["predict_features"], outliers, _ = dk.feature_pipeline.transform(
            dd["predict_features"], outlier_check=True)
        # DI_threshold is configured under freqai.feature_parameters. Fall back to the
        # top-level key this function used to read, so existing configs keep working.
        feature_parameters = self.freqai_info.get("feature_parameters", {})
        di_threshold = feature_parameters.get(
            "DI_threshold", self.freqai_info.get("DI_threshold", 0))
        if di_threshold > 0:
            dk.DI_values = dk.feature_pipeline["di"].di_values
        else:
            dk.DI_values = np.zeros(outliers.shape[0])
        dk.do_predict = outliers
        return
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc