• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.87
/freqtrade/freqai/base_models/BaseRegressionModel.py
1
import logging
1✔
2
from time import time
1✔
3
from typing import Any, Tuple
1✔
4

5
import numpy as np
1✔
6
import numpy.typing as npt
1✔
7
from pandas import DataFrame
1✔
8

9
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
10
from freqtrade.freqai.freqai_interface import IFreqaiModel
1✔
11

12

13
logger = logging.getLogger(__name__)
1✔
14

15

16
class BaseRegressionModel(IFreqaiModel):
    """
    Base class for regression type models (e.g. CatBoost, LightGBM, XGBoost etc.).
    User *must* inherit from this class and implement fit(). See example scripts
    such as prediction_models/CatboostRegressor.py for guidance.
    """

    def train(
        self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
    ) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the datakitchen
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period. Must contain a
            "date" column, which is used to log the training window.
        :param pair: Pair being trained; only used in log messages here.
        :param dk: FreqaiDataKitchen for this pair. Its feature_pipeline and label_pipeline
            attributes are (re)assigned and fitted by this method.
        :param kwargs: Not used by this implementation; accepted for interface compatibility.
        :return: The trained model, exactly as returned by self.fit(); usable for inference
            via self.predict().
        """

        logger.info(f"-------------------- Starting training {pair} --------------------")

        start_time = time()

        # filter the features requested by user in the configuration file and elegantly handle NaNs
        features_filtered, labels_filtered = dk.filter_features(
            unfiltered_df,
            dk.training_features_list,
            dk.label_list,
            training_filter=True,
        )

        start_date = unfiltered_df["date"].iloc[0].strftime("%Y-%m-%d")
        end_date = unfiltered_df["date"].iloc[-1].strftime("%Y-%m-%d")
        logger.info(f"-------------------- Training on data from {start_date} to "
                    f"{end_date} --------------------")
        # split data into train/test data.
        dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
        # Label statistics are fitted here unless fit_live_predictions_candles is enabled
        # while running live (presumably they are then derived from live predictions
        # elsewhere -- TODO confirm).
        if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
            dk.fit_labels()
        dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
        dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)

        # Fit both pipelines on the training split only; the test split below is
        # transformed with the parameters learned here.
        (dd["train_features"],
         dd["train_labels"],
         dd["train_weights"]) = dk.feature_pipeline.fit_transform(dd["train_features"],
                                                                  dd["train_labels"],
                                                                  dd["train_weights"])
        dd["train_labels"], _, _ = dk.label_pipeline.fit_transform(dd["train_labels"])

        # A test_size of 0 means no test split exists, so there is nothing to transform.
        if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) != 0:
            (dd["test_features"],
             dd["test_labels"],
             dd["test_weights"]) = dk.feature_pipeline.transform(dd["test_features"],
                                                                 dd["test_labels"],
                                                                 dd["test_weights"])
            dd["test_labels"], _, _ = dk.label_pipeline.transform(dd["test_labels"])

        # Report the shape of the transformed training frame that is actually handed to
        # fit(); pipeline steps may change the column count.
        logger.info(
            f"Training model on {len(dd['train_features'].columns)} features"
        )
        logger.info(f"Training model on {len(dd['train_features'])} data points")

        model = self.fit(dd, dk)

        end_time = time()

        logger.info(f"-------------------- Done training {pair} "
                    f"({end_time - start_time:.2f} secs) --------------------")

        return model

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: FreqaiDataKitchen holding the pipelines fitted in train(). Its
            data_dictionary["prediction_features"], DI_values and do_predict attributes
            are assigned by this method.
        :param kwargs: Not used by this implementation; accepted for interface compatibility.
        :return: Tuple of (pred_df, do_predict):
            pred_df: dataframe containing the predictions, one column per label.
            do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
            data (NaNs) or felt uncertain about data (PCA and DI index).
        """

        dk.find_features(unfiltered_df)
        dk.data_dictionary["prediction_features"], _ = dk.filter_features(
            unfiltered_df, dk.training_features_list, training_filter=False
        )

        # outlier_check=True makes the pipeline also return the outlier mask used as do_predict.
        dk.data_dictionary["prediction_features"], outliers, _ = dk.feature_pipeline.transform(
            dk.data_dictionary["prediction_features"], outlier_check=True)

        predictions = self.model.predict(dk.data_dictionary["prediction_features"])
        # Force a (rows, n_labels) layout so the DataFrame constructor below gets one
        # column per label; CONV_WIDTH > 1 models are left untouched.
        if self.CONV_WIDTH == 1:
            predictions = np.reshape(predictions, (-1, len(dk.label_list)))

        pred_df = DataFrame(predictions, columns=dk.label_list)

        # Map predictions back through the label pipeline fitted in train().
        pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)

        # Look the "di" step up once. When the lookup is falsy (presumably no DI step was
        # configured -- TODO confirm), DI values default to zeros, one per prediction row.
        di_step = dk.feature_pipeline["di"]
        if di_step:
            dk.DI_values = di_step.di_values
        else:
            dk.DI_values = np.zeros(outliers.shape[0])
        dk.do_predict = outliers

        return (pred_df, dk.do_predict)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc