• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%
9394559170

push

github

xmatthias
Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.5
/freqtrade/freqai/prediction_models/XGBoostRFClassifier.py
1
import logging
1✔
2
from typing import Any, Dict, Tuple
1✔
3

4
import numpy as np
1✔
5
import numpy.typing as npt
1✔
6
import pandas as pd
1✔
7
from pandas import DataFrame
1✔
8
from pandas.api.types import is_integer_dtype
1✔
9
from sklearn.preprocessing import LabelEncoder
1✔
10
from xgboost import XGBRFClassifier
1✔
11

12
from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel
1✔
13
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
1✔
14

15

16
# Module-level logger named after this module (not referenced in the code shown here).
logger = logging.getLogger(__name__)
1✔
17

18

19
class XGBoostRFClassifier(BaseClassifierModel):
    """
    User created prediction model. The class inherits BaseClassifierModel, which
    gives it access to the FreqAI model interface. Typically, users would use this
    to override the common `fit()`, `train()`, or `predict()` methods to add their
    custom data handling tools or change various aspects of the training that
    cannot be configured via the top level config.json file.

    This model wraps xgboost's random-forest classifier (`XGBRFClassifier`).
    Non-integer class labels are integer-encoded with sklearn's `LabelEncoder`
    before training, and `predict()` maps integer predictions back to the
    original label names.
    """

    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        User sets up the training and test data to fit their desired model here.
        :param data_dictionary: the dictionary holding all data for train, test,
            labels, weights
        :param dk: The datakitchen object for the current coin/model
        :return: the fitted XGBRFClassifier instance
        """

        X = data_dictionary["train_features"].to_numpy()
        # Only the first label column is used as the classification target.
        y = data_dictionary["train_labels"].to_numpy()[:, 0]

        # Encode non-integer (e.g. string) labels to integers before training.
        le = LabelEncoder()
        if not is_integer_dtype(y):
            y = pd.Series(le.fit_transform(y), dtype="int64")

        if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:
            # No test split configured, so there is no evaluation set.
            eval_set = None
        else:
            test_features = data_dictionary["test_features"].to_numpy()
            test_labels = data_dictionary["test_labels"].to_numpy()[:, 0]

            # Reuse the encoder fitted on the training labels so train and test
            # share the same label -> integer mapping.
            # NOTE(review): if the training labels were already integers, `le` was
            # never fitted and transform() will raise for non-integer test labels.
            if not is_integer_dtype(test_labels):
                test_labels = pd.Series(le.transform(test_labels), dtype="int64")

            eval_set = [(test_features, test_labels)]

        train_weights = data_dictionary["train_weights"]

        # Model passed as `xgb_model` to continue training from; presumably None
        # when no earlier model exists for this pair — TODO confirm.
        init_model = self.get_init_model(dk.pair)

        model = XGBRFClassifier(**self.model_training_parameters)

        model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights,
                  xgb_model=init_model)

        return model

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :param dk: The datakitchen object for the current coin/model
        :return: a tuple of
            pred_df: dataframe containing the predictions
            do_predict: np.array of 1s and 0s to indicate places where freqai needed
            to remove data (NaNs) or felt uncertain about data (PCA and DI index)
        """

        (pred_df, dk.do_predict) = super().predict(unfiltered_df, dk, **kwargs)

        # Rebuild the label encoding from the known class names. LabelEncoder sorts
        # its classes, so this is assumed to match the encoding used in fit().
        # Assumes the keys of dk.data['labels_std'] are the class names seen during
        # training — TODO confirm.
        le = LabelEncoder()
        label = dk.label_list[0]
        labels_before = list(dk.data['labels_std'].keys())
        labels_after = le.fit_transform(labels_before).tolist()
        # Map integer class predictions back to the original label names.
        pred_df[label] = le.inverse_transform(pred_df[label])
        # Rename any columns named by an encoded class index back to the original
        # class name (rename ignores keys that are not present as columns).
        pred_df = pred_df.rename(columns=dict(zip(labels_after, labels_before)))

        return (pred_df, dk.do_predict)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc