freqtrade / freqtrade / 9394559170 (github push)

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%

xmatthias: Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%), 0.95 hits per line

Source File: /freqtrade/freqai/prediction_models/ReinforcementLearner.py (81.16% covered)
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Type

import torch as th
from stable_baselines3.common.callbacks import ProgressBarCallback

from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel


logger = logging.getLogger(__name__)


class ReinforcementLearner(BaseReinforcementLearningModel):
    """
    Reinforcement Learning Model prediction model.

    Users can inherit from this class to make their own RL model with custom
    environment/training controls. Define the file as follows:

    ```
    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner

    class MyCoolRLModel(ReinforcementLearner):
    ```

    Save the file to `user_data/freqaimodels`, then run it with:

    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat

    Here the user can override any of the functions
    available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
    is where the user overrides `MyRLEnv` (see below) to define a custom
    `calculate_reward()` function, or to override any other parts of the environment.

    This class also allows users to override any other part of the IFreqaiModel tree.
    For example, the user can override `def fit()`, `def train()` or `def predict()`
    to take fine-tuned control over these processes.

    Another common override is `def data_cleaning_predict()`, where the user can
    take fine-tuned control over the data handling pipeline.
    """

    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        User customizable fit method
        :param data_dictionary: dict = common data dictionary containing all train/test
            features/labels/weights.
        :param dk: FreqaiDataKitchen = data kitchen for the current pair.
        :return:
        model Any = trained model to be used for inference in dry/live/backtesting
        """
        train_df = data_dictionary["train_features"]
        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)

        policy_kwargs = dict(activation_fn=th.nn.ReLU,
                             net_arch=self.net_arch)

        if self.activate_tensorboard:
            tb_path = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0])
        else:
            tb_path = None

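        # Start from a fresh agent unless continual_learning is enabled and a
        # previously trained agent already exists for this pair.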
        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
                                    tensorboard_log=tb_path,
                                    **self.freqai_info.get('model_training_parameters', {})
                                    )
        else:
            logger.info('Continual training activated - starting training from previously '
                        'trained agent.')
            model = self.dd.model_dictionary[dk.pair]
            model.set_env(self.train_env)
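        # The eval and tensorboard callbacks always run; a progress bar callback is
        # prepended only when rl_config["progress_bar"] is set.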
        callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
        progressbar_callback: Optional[ProgressBarCallback] = None
        if self.rl_config.get('progress_bar', False):
            progressbar_callback = ProgressBarCallback()
            callbacks.insert(0, progressbar_callback)

        try:
            model.learn(
                total_timesteps=int(total_timesteps),
                callback=callbacks,
            )
        finally:
            if progressbar_callback:
                progressbar_callback.on_training_end()

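        # Prefer the best checkpoint written by the eval callback during training;
        # fall back to the final model if no best_model.zip was saved.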
        if Path(dk.data_path / "best_model.zip").is_file():
            logger.info('Callback found a best model.')
            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
            return best_model

        logger.info("Couldn't find best model, using final model instead.")

        return model

    MyRLEnv: Type[BaseEnvironment]

    class MyRLEnv(Base5ActionRLEnv):  # type: ignore[no-redef]
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.

            Warning!
            This function is a showcase designed to demonstrate as many environment
            control features as possible. It is also designed to run quickly on small
            computers. It is a benchmark; it is *not* for live production.

            :param action: int = The action made by the agent for the current candle.
            :return:
            float = the reward to give to the agent for the current step (used for
                optimization of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                self.tensorboard_log("invalid", category="actions")
                return -2

            pnl = self.get_unrealized_profit()
            factor = 100.

            # reward agent for entering trades
            if (action == Actions.Long_enter.value
                    and self._position == Positions.Neutral):
                return 25
            if (action == Actions.Short_enter.value
                    and self._position == Positions.Neutral):
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

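            # Duration-based scaling: boost the payout factor for trades closed within
            # max_trade_duration candles, damp it for trades held longer.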
            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            if trade_duration <= max_trade_duration:
                factor *= 1.5
            elif trade_duration > max_trade_duration:
                factor *= 0.5

            # discourage sitting in position
            if (self._position in (Positions.Short, Positions.Long) and
                    action == Actions.Neutral.value):
                return -1 * trade_duration / max_trade_duration

            # close long
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            # close short
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            return 0.
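
The class docstring above describes the intended extension pattern: subclass `ReinforcementLearner` in `user_data/freqaimodels` and override the nested `MyRLEnv.calculate_reward()`. The sketch below is illustrative rather than part of the covered file; the reward values are arbitrary, and it only assumes the API calls that appear in the source above (`_is_valid`, `get_unrealized_profit`, `_position`, and the `Actions`/`Positions` enums).

```
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions


class MyCoolRLModel(ReinforcementLearner):
    """Example user model: only the environment's reward shaping is customised."""

    class MyRLEnv(Base5ActionRLEnv):  # type: ignore[no-redef]
        def calculate_reward(self, action: int) -> float:
            # Penalise invalid actions, mildly reward entries, and pay out the
            # (scaled) unrealised profit on exits - a stripped-down variant of
            # the benchmark reward shown above.
            if not self._is_valid(action):
                return -2.0
            pnl = self.get_unrealized_profit()
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
                    and self._position == Positions.Neutral):
                return 10.0
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                return float(pnl * 100)
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                return float(pnl * 100)
            return 0.0
```

Saved as a file in `user_data/freqaimodels/`, it would be selected with the command already quoted in the docstring: `freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat`.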
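For reference, `fit()` and `calculate_reward()` above read a handful of configuration keys. The dict below is a hedged sketch collecting those knobs with illustrative values; it mirrors the lookups made in the code (`self.freqai_info[...]`, `self.rl_config[...]`) and does not document the full freqai configuration schema.

```
# Illustrative values only; the keys mirror the lookups made above.
rl_settings_sketch = {
    "rl_config": {
        "train_cycles": 25,                 # total_timesteps = train_cycles * len(train_df)
        "progress_bar": True,               # enables the ProgressBarCallback during model.learn()
        "max_trade_duration_candles": 300,  # duration threshold used to scale the reward factor
        "model_reward_parameters": {
            "win_reward_factor": 2,         # applied when pnl > profit_aim * rr on exit
        },
    },
    "model_training_parameters": {},        # forwarded verbatim to the model constructor (self.MODELCLASS)
}
```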