Coveralls coverage report for freqtrade/freqtrade, build 4131167254 (push, via github-actions on GitHub)
Build status: pending completion
Commit: Merge pull request #7983 from stash86/bt-metrics
Repository coverage: 16866 of 17748 relevant lines covered (95.03%), 0.95 hits per line

Source file: /freqtrade/freqai/prediction_models/ReinforcementLearner.py
File coverage: 83.64%
In this build the report marks the following lines as uncovered (×): the continual-learning branch of fit() (re-using a previously trained agent), the fallback that returns the final model when no best_model.zip is found, the trade_duration > max_trade_duration branch of the reward, and the win_reward_factor multiplications applied when closing a winning long or short. Every other executable line shows one hit (1✔).

import logging
from pathlib import Path
from typing import Any, Dict

import torch as th

from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel


logger = logging.getLogger(__name__)


class ReinforcementLearner(BaseReinforcementLearningModel):
    """
    Reinforcement Learning Model prediction model.

    Users can inherit from this class to make their own RL model with custom
    environment/training controls. Define the file as follows:

    ```
    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner

    class MyCoolRLModel(ReinforcementLearner):
    ```

    Save the file to `user_data/freqaimodels`, then run it with:

    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat

    Here the users can override any of the functions
    available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
    is where the user overrides `MyRLEnv` (see below), to define custom
    `calculate_reward()` function, or to override any other parts of the environment.

    This class also allows users to override any other part of the IFreqaiModel tree.
    For example, the user can override `def fit()` or `def train()` or `def predict()`
    to take fine-tuned control over these processes.

    Another common override may be `def data_cleaning_predict()` where the user can
    take fine-tuned control over the data handling pipeline.
    """

    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        User customizable fit method
        :param data_dictionary: dict = common data dictionary containing all train/test
            features/labels/weights.
        :param dk: FreqaiDatakitchen = data kitchen for current pair.
        :return:
        model Any = trained model to be used for inference in dry/live/backtesting
        """
        train_df = data_dictionary["train_features"]
        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)

        policy_kwargs = dict(activation_fn=th.nn.ReLU,
                             net_arch=self.net_arch)

        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
                                    tensorboard_log=Path(
                                        dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
                                    **self.freqai_info.get('model_training_parameters', {})
                                    )
        else:
            logger.info('Continual training activated - starting training from previously '
                        'trained agent.')
            model = self.dd.model_dictionary[dk.pair]
            model.set_env(self.train_env)

        model.learn(
            total_timesteps=int(total_timesteps),
            callback=[self.eval_callback, self.tensorboard_callback]
        )

        if Path(dk.data_path / "best_model.zip").is_file():
            logger.info('Callback found a best model.')
            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
            return best_model

        logger.info('Couldnt find best model, using final model instead.')

        return model

    class MyRLEnv(Base5ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.
            :param action: int = The action made by the agent for the current candle.
            :return:
            float = the reward to give to the agent for current step (used for optimization
                of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                self.tensorboard_log("is_valid")
                return -2

            pnl = self.get_unrealized_profit()
            factor = 100.

            # reward agent for entering trades
            if (action == Actions.Long_enter.value
                    and self._position == Positions.Neutral):
                return 25
            if (action == Actions.Short_enter.value
                    and self._position == Positions.Neutral):
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            if trade_duration <= max_trade_duration:
                factor *= 1.5
            elif trade_duration > max_trade_duration:
                factor *= 0.5

            # discourage sitting in position
            if (self._position in (Positions.Short, Positions.Long) and
                    action == Actions.Neutral.value):
                return -1 * trade_duration / max_trade_duration

            # close long
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            # close short
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            return 0.
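To make the reward arithmetic above concrete, the short standalone sketch below replays the close-long branch for one hypothetical scenario. All of the numbers (pnl, profit_aim, rr, trade duration) are illustrative assumptions, not values taken from this build or from any freqtrade configuration.

```python
# Hypothetical walk-through of the close-long reward branch above.
# profit_aim, rr and win_reward_factor normally come from the user's
# FreqAI configuration; the values here are assumptions for illustration.

pnl = 0.04                  # unrealized profit when Long_exit is chosen
profit_aim, rr = 0.02, 1.0  # assumed config values
max_trade_duration = 300    # default for 'max_trade_duration_candles'
trade_duration = 120        # candles since the trade was opened
win_reward_factor = 2       # default used in the code above

factor = 100.
if trade_duration <= max_trade_duration:
    factor *= 1.5           # closed within the allowed duration
else:
    factor *= 0.5           # held too long

if pnl > profit_aim * rr:   # winning trade: apply the bonus multiplier
    factor *= win_reward_factor

reward = float(pnl * factor)
print(reward)               # 0.04 * 100 * 1.5 * 2 = 12.0
```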
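Following the class docstring, a user model only needs to subclass ReinforcementLearner and override the pieces it cares about, typically the environment's reward. The sketch below is a minimal example in the spirit of the docstring; the class name MyCoolRLModel comes from the docstring itself, and the simplified reward logic is a placeholder assumption rather than code from the repository.

```python
# user_data/freqaimodels/MyCoolRLModel.py  (hypothetical user file)
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions


class MyCoolRLModel(ReinforcementLearner):
    """Example user model: keeps fit() from ReinforcementLearner, swaps only the reward."""

    class MyRLEnv(Base5ActionRLEnv):
        def calculate_reward(self, action: int) -> float:
            # Penalize invalid actions, reward unrealized profit on exits,
            # and stay neutral otherwise (a deliberately simple placeholder).
            if not self._is_valid(action):
                return -2.
            pnl = self.get_unrealized_profit()
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                return float(pnl * 100.)
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                return float(pnl * 100.)
            return 0.
```

As the docstring states, such a file is saved to `user_data/freqaimodels` and selected on the command line with `freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat`.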