• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.98
/freqtrade/freqai/base_models/BasePyTorchClassifier.py
1
import logging
1✔
2
from time import time
1✔
3
from typing import Any, Dict, List, Tuple
1✔
4

5
import numpy as np
1✔
6
import numpy.typing as npt
1✔
7
import pandas as pd
1✔
8
import torch
1✔
9
from pandas import DataFrame
1✔
10
from torch.nn import functional as F
1✔
11

12
from freqtrade.exceptions import OperationalException
1✔
13
from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel
1✔
14
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
15

16

17
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
1✔
18

19

20
class BasePyTorchClassifier(BasePyTorchModel):
    """
    A PyTorch implementation of a classifier.
    User must implement fit method

    Important!

    - User must declare the target class names in the strategy,
    under IStrategy.set_freqai_targets method.

    for example, in your strategy:
    ```
        def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs):
            self.freqai.class_names = ["down", "up"]
            dataframe['&s-up_or_down'] = np.where(dataframe["close"].shift(-100) >
                                                  dataframe["close"], 'up', 'down')

            return dataframe
    ```
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Lookup tables between class-name strings and integer indices.
        # Both stay None until init_class_names_to_index_mapping() is called.
        self.class_name_to_index = None
        self.index_to_class_name = None

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: The datakitchen object
        :return:
        :pred_df: dataframe whose first column (named dk.label_list[0]) holds the
            predicted class name, followed by one probability column per class name
        :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
        data (NaNs) or felt uncertain about data (PCA and DI index)
        :raises ValueError: if 'class_names' doesn't exist in model meta_data.
        """

        class_names = self.model.model_meta_data.get("class_names", None)
        if not class_names:
            raise ValueError(
                "Missing class names. "
                "self.model.model_meta_data['class_names'] is None."
            )

        # The name <-> index mapping may not exist yet in this process (presumably
        # when predicting with a model that was not trained here); build it from
        # the class names stored in the model's metadata.
        if not self.class_name_to_index:
            self.init_class_names_to_index_mapping(class_names)

        dk.find_features(unfiltered_df)
        filtered_df, _ = dk.filter_features(
            unfiltered_df, dk.training_features_list, training_filter=False
        )

        dk.data_dictionary["prediction_features"] = filtered_df

        dk.data_dictionary["prediction_features"], outliers, _ = dk.feature_pipeline.transform(
            dk.data_dictionary["prediction_features"], outlier_check=True)

        x = self.data_convertor.convert_x(
            dk.data_dictionary["prediction_features"],
            device=self.device
        )
        self.model.model.eval()
        # Pure inference: disable autograd so no computation graph is recorded.
        with torch.no_grad():
            logits = self.model.model(x)
            probs = F.softmax(logits, dim=-1)
            predicted_classes = torch.argmax(probs, dim=-1)
        predicted_classes_str = self.decode_class_names(predicted_classes)
        # used .tolist to convert probs into an iterable, in this way Tensors
        # are automatically moved to the CPU first if necessary.
        pred_df_prob = DataFrame(probs.detach().tolist(), columns=class_names)
        pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
        pred_df = pd.concat([pred_df, pred_df_prob], axis=1)

        if dk.feature_pipeline["di"]:
            dk.DI_values = dk.feature_pipeline["di"].di_values
        else:
            dk.DI_values = np.zeros(outliers.shape[0])
        dk.do_predict = outliers

        return (pred_df, dk.do_predict)

    def encode_class_names(
            self,
            data_dictionary: Dict[str, pd.DataFrame],
            dk: FreqaiDataKitchen,
            class_names: List[str],
    ):
        """
        encode class name, str -> int

        For every split in self.splits, the column named dk.label_list[0] of
        data_dictionary[f"{split}_labels"] is validated against class_names and
        then replaced, in place, by the integer index of each class name.
        :param data_dictionary: dict holding the "<split>_labels" dataframes
        :param dk: The datakitchen object (provides the label column name)
        :param class_names: the allowed class names
        :raises OperationalException: if a label is not one of class_names
        """

        target_column_name = dk.label_list[0]
        name_to_index = self.class_name_to_index
        for split in self.splits:
            label_df = data_dictionary[f"{split}_labels"]
            self.assert_valid_class_names(label_df[target_column_name], class_names)
            label_df[target_column_name] = [
                name_to_index[name] for name in label_df[target_column_name]
            ]

    @staticmethod
    def assert_valid_class_names(
            target_column: pd.Series,
            class_names: List[str]
    ):
        """
        Ensure every value in target_column is one of class_names.
        :param target_column: label values to check
        :param class_names: the allowed class names
        :raises OperationalException: listing the unexpected labels, if any
        """
        non_defined_labels = set(target_column) - set(class_names)
        if non_defined_labels:
            # Adjacent literals are concatenated into one message; a comma here
            # would pass two separate args and render the message as a tuple.
            raise OperationalException(
                f"Found non defined labels: {non_defined_labels}, "
                f"expecting labels: {class_names}"
            )

    def decode_class_names(self, class_ints: torch.Tensor) -> List[str]:
        """
        decode class name, int -> str
        :param class_ints: tensor of class indices (one-dimensional, as produced
            by argmax over the last axis in predict)
        :return: list of class names, one per element of class_ints
        """
        # A single .tolist() call converts the whole tensor (moving it to the CPU
        # if necessary) instead of calling .item() per element.
        index_to_name = self.index_to_class_name
        return [index_to_name[index] for index in class_ints.tolist()]

    def init_class_names_to_index_mapping(self, class_names):
        """
        Build the name -> index and index -> name lookup tables.
        The index of a class is its position in class_names.
        :param class_names: ordered sequence of class names
        """
        self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
        self.index_to_class_name = dict(enumerate(class_names))
        logger.info(f"encoded class name to index: {self.class_name_to_index}")

    def convert_label_column_to_int(
            self,
            data_dictionary: Dict[str, pd.DataFrame],
            dk: FreqaiDataKitchen,
            class_names: List[str]
    ):
        """
        (Re)build the class-name mapping from class_names, then encode the label
        columns of data_dictionary from class names to integer indices in place.
        :param data_dictionary: dict holding the "<split>_labels" dataframes
        :param dk: The datakitchen object
        :param class_names: ordered sequence of class names
        """
        self.init_class_names_to_index_mapping(class_names)
        self.encode_class_names(data_dictionary, dk, class_names)

    def get_class_names(self) -> List[str]:
        """
        Return the class names configured for this model.
        :raises ValueError: if self.class_names is empty or unset (falsy)
        """
        if not self.class_names:
            raise ValueError(
                "self.class_names is empty, "
                "set self.freqai.class_names = ['class a', 'class b', 'class c'] "
                "inside IStrategy.set_freqai_targets method."
            )

        return self.class_names

    def train(
        self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
    ) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the datakitchen
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: the pair being trained (used for logging)
        :param dk: The datakitchen object
        :return:
        :model: Trained model which can be used to inference (self.predict)
        """

        logger.info(f"-------------------- Starting training {pair} --------------------")

        start_time = time()

        features_filtered, labels_filtered = dk.filter_features(
            unfiltered_df,
            dk.training_features_list,
            dk.label_list,
            training_filter=True,
        )

        # split data into train/test data.
        dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
        if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
            dk.fit_labels()

        dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)

        (dd["train_features"],
         dd["train_labels"],
         dd["train_weights"]) = dk.feature_pipeline.fit_transform(dd["train_features"],
                                                                  dd["train_labels"],
                                                                  dd["train_weights"])

        # A test_size of 0 means no test split exists, so there is nothing to transform.
        if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) != 0:
            (dd["test_features"],
             dd["test_labels"],
             dd["test_weights"]) = dk.feature_pipeline.transform(dd["test_features"],
                                                                 dd["test_labels"],
                                                                 dd["test_weights"])

        # Report on `dd`, the dictionary that is actually handed to fit() below.
        logger.info(
            f"Training model on {len(dd['train_features'].columns)} features"
        )
        logger.info(f"Training model on {len(dd['train_features'])} data points")

        model = self.fit(dd, dk)
        end_time = time()

        logger.info(f"-------------------- Done training {pair} "
                    f"({end_time - start_time:.2f} secs) --------------------")

        return model
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc