• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.0
/freqtrade/freqai/base_models/BaseClassifierModel.py
1
import logging
1✔
2
from time import time
1✔
3
from typing import Any, Tuple
1✔
4

5
import numpy as np
1✔
6
import numpy.typing as npt
1✔
7
import pandas as pd
1✔
8
from pandas import DataFrame
1✔
9

10
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
11
from freqtrade.freqai.freqai_interface import IFreqaiModel
1✔
12

13

14
# Module-level logger named after this module; used by the training/prediction log messages below.
logger = logging.getLogger(__name__)
1✔
15

16

17
class BaseClassifierModel(IFreqaiModel):
    """
    Base class for classification type models (e.g. Catboost, LightGBM, XGboost etc.).
    User *must* inherit from this class and set fit(). See example scripts
    such as prediction_models/CatboostClassifier.py for guidance.
    """

    def train(
        self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
    ) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the datakitchen
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: pair name; only used in the log messages emitted here
        :param dk: datakitchen for the current pair; provides the training feature/label lists
            and receives the fitted feature pipeline (dk.feature_pipeline)
        :return:
        :model: Trained model, as returned by self.fit(), which can be used to inference
            (self.predict)
        """

        logger.info(f"-------------------- Starting training {pair} --------------------")

        start_time = time()

        # filter the features requested by user in the configuration file and elegantly handle NaNs
        features_filtered, labels_filtered = dk.filter_features(
            unfiltered_df,
            dk.training_features_list,
            dk.label_list,
            training_filter=True,
        )

        start_date = unfiltered_df["date"].iloc[0].strftime("%Y-%m-%d")
        end_date = unfiltered_df["date"].iloc[-1].strftime("%Y-%m-%d")
        logger.info(f"-------------------- Training on data from {start_date} to "
                    f"{end_date} --------------------")
        # split data into train/test data.
        dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
        # Labels are fitted here unless live label fitting is configured and we are live.
        if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
            dk.fit_labels()
        dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)

        # The pipeline is fit on the training split only; the test split below is
        # merely transform()-ed with that already-fitted pipeline.
        (dd["train_features"],
         dd["train_labels"],
         dd["train_weights"]) = dk.feature_pipeline.fit_transform(dd["train_features"],
                                                                  dd["train_labels"],
                                                                  dd["train_weights"])

        # Skip when test_size is 0 (presumably no test split was produced in that case).
        if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) != 0:
            (dd["test_features"],
             dd["test_labels"],
             dd["test_weights"]) = dk.feature_pipeline.transform(dd["test_features"],
                                                                 dd["test_labels"],
                                                                 dd["test_weights"])

        # Report the post-pipeline feature count from the same dict that is handed to fit().
        logger.info(
            f"Training model on {len(dd['train_features'].columns)} features"
        )
        logger.info(f"Training model on {len(dd['train_features'])} data points")

        model = self.fit(dd, dk)

        end_time = time()

        logger.info(f"-------------------- Done training {pair} "
                    f"({end_time - start_time:.2f} secs) --------------------")

        return model

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: datakitchen for the current pair; holds the fitted feature pipeline and
            receives data_dictionary["prediction_features"], DI_values and do_predict
        :return:
        :pred_df: dataframe containing the predicted labels (one column per entry of
            dk.label_list) followed by one probability column per class in self.model.classes_
        :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
        data (NaNs) or felt uncertain about data (PCA and DI index)
        """

        dk.find_features(unfiltered_df)
        filtered_df, _ = dk.filter_features(
            unfiltered_df, dk.training_features_list, training_filter=False
        )

        dk.data_dictionary["prediction_features"] = filtered_df

        dk.data_dictionary["prediction_features"], outliers, _ = dk.feature_pipeline.transform(
            dk.data_dictionary["prediction_features"], outlier_check=True)

        predictions = self.model.predict(dk.data_dictionary["prediction_features"])
        if self.CONV_WIDTH == 1:
            # Force a (rows, n_labels) shape so the DataFrame gets one column per label.
            predictions = np.reshape(predictions, (-1, len(dk.label_list)))

        pred_df = DataFrame(predictions, columns=dk.label_list)

        predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
        if self.CONV_WIDTH == 1:
            # Same reshaping for probabilities: one column per class the model knows.
            predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_)))
        pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)

        # Labels first, then per-class probabilities, aligned row by row.
        pred_df = pd.concat([pred_df, pred_df_prob], axis=1)

        # Look the DI step up once; fall back to zeros (one per outlier row) when absent.
        di_step = dk.feature_pipeline["di"]
        if di_step:
            dk.DI_values = di_step.di_values
        else:
            dk.DI_values = np.zeros(outliers.shape[0])
        dk.do_predict = outliers

        return (pred_df, dk.do_predict)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc