freqtrade / freqtrade, build 4131167254 (push via github-actions / GitHub), status: pending completion
Commit: Merge pull request #7983 from stash86/bt-metrics

Repository coverage: 16866 of 17748 relevant lines covered (95.03%), 0.95 hits per line

Source file: /freqtrade/freqai/RL/BaseReinforcementLearningModel.py (file coverage: 73.98%)

import importlib
import logging
from abc import abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import gym
import numpy as np
import numpy.typing as npt
import pandas as pd
import torch as th
import torch.multiprocessing
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv

from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
from freqtrade.persistence import Trade


logger = logging.getLogger(__name__)

torch.multiprocessing.set_sharing_strategy('file_system')

SB3_MODELS = ['PPO', 'A2C', 'DQN']
SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']


class BaseReinforcementLearningModel(IFreqaiModel):
    """
    User created Reinforcement Learning Model prediction class
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(config=kwargs['config'])
        self.max_threads = min(self.freqai_info['rl_config'].get(
            'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
        th.set_num_threads(self.max_threads)
        self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
        self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
        self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
        self.eval_callback: Optional[EvalCallback] = None
        self.model_type = self.freqai_info['rl_config']['model_type']
        self.rl_config = self.freqai_info['rl_config']
        self.continual_learning = self.freqai_info.get('continual_learning', False)
        if self.model_type in SB3_MODELS:
            import_str = 'stable_baselines3'
        elif self.model_type in SB3_CONTRIB_MODELS:
            import_str = 'sb3_contrib'
        else:
            raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
                                       f'sb3_contrib. please choose one of {SB3_MODELS} or '
                                       f'{SB3_CONTRIB_MODELS}')

        mod = importlib.import_module(import_str, self.model_type)
        self.MODELCLASS = getattr(mod, self.model_type)
        self.policy_type = self.freqai_info['rl_config']['policy_type']
        self.unset_outlier_removal()
        self.net_arch = self.rl_config.get('net_arch', [128, 128])
        self.dd.model_type = import_str
        self.tensorboard_callback: TensorboardCallback = \
            TensorboardCallback(verbose=1, actions=BaseActions)

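    # For reference, an illustrative (not authoritative) "rl_config" block covering the keys
    # read by this class; the values below are examples only, not defaults taken from this file
    # (except net_arch and the win_reward_factor fallback, which appear in the code):
    #
    #   "freqai": {
    #       "rl_config": {
    #           "model_type": "PPO",              # must be in SB3_MODELS or SB3_CONTRIB_MODELS
    #           "policy_type": "MlpPolicy",       # assumed stable_baselines3 policy name
    #           "cpu_count": 2,
    #           "net_arch": [128, 128],
    #           "add_state_info": False,
    #           "max_trade_duration_candles": 300,
    #           "model_reward_parameters": {"win_reward_factor": 2}
    #       }
    #   }
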
    def unset_outlier_removal(self):
        """
        If user has activated any function that may remove training points, this
        function will set them to false and warn them
        """
        if self.ft_params.get('use_SVM_to_remove_outliers', False):
            self.ft_params.update({'use_SVM_to_remove_outliers': False})
            logger.warning('User tried to use SVM with RL. Deactivating SVM.')
        if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
            logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
        if self.freqai_info['data_split_parameters'].get('shuffle', False):
            self.freqai_info['data_split_parameters'].update({'shuffle': False})
            logger.warning('User tried to shuffle training data. Setting shuffle to False')

    def train(
        self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
    ) -> Any:
        """
        Filter the training data and train a model on it. Train makes heavy use of the datakitchen
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: pair to train the model on
        :returns:
        :model: Trained model which can be used for inference (self.predict)
        """

        logger.info("--------------------Starting training " f"{pair} --------------------")

        features_filtered, labels_filtered = dk.filter_features(
            unfiltered_df,
            dk.training_features_list,
            dk.label_list,
            training_filter=True,
        )

        data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
            features_filtered, labels_filtered)
        dk.fit_labels()  # FIXME useless for now, but just satiating append methods

        # normalize all data based on train_dataset only
        prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
        data_dictionary = dk.normalize_data(data_dictionary)

        # data cleaning/analysis
        self.data_cleaning_train(dk)

        logger.info(
            f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
            f' features and {len(data_dictionary["train_features"])} data points'
        )

        self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)

        model = self.fit(data_dictionary, dk)

        logger.info(f"--------------------done training {pair}--------------------")

        return model

    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
                                        prices_train: DataFrame, prices_test: DataFrame,
                                        dk: FreqaiDataKitchen):
        """
        User can override this if they are using a custom MyRLEnv
        :param data_dictionary: dict = common data dictionary containing train and test
            features/labels/weights.
        :param prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
            environment during training or testing
        :param dk: FreqaiDataKitchen = the datakitchen for the current pair
        """
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        env_info = self.pack_env_dict()

        self.train_env = self.MyRLEnv(df=train_df,
                                      prices=prices_train,
                                      **env_info)
        self.eval_env = Monitor(self.MyRLEnv(df=test_df,
                                             prices=prices_test,
                                             **env_info))
        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                          render=False, eval_freq=len(train_df),
                                          best_model_save_path=str(dk.data_path))

        actions = self.train_env.get_actions()
        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

    def pack_env_dict(self) -> Dict[str, Any]:
        """
        Create dictionary of environment arguments
        """
        env_info = {"window_size": self.CONV_WIDTH,
                    "reward_kwargs": self.reward_params,
                    "config": self.config,
                    "live": self.live,
                    "can_short": self.can_short}
        if self.data_provider:
            env_info["fee"] = self.data_provider._exchange \
                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore

        return env_info

    @abstractmethod
    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        Agent customizations and abstract Reinforcement Learning customizations
        go in here. Abstract method, so this function must be overridden by
        user class.
        """
        return

    def get_state_info(self, pair: str) -> Tuple[float, float, int]:
        """
        State info during dry/live (not backtesting) which is fed back
        into the model.
        :param pair: str = COIN/STAKE to get the environment information for
        :return:
        :market_side: float = representing short, long, or neutral for
            pair
        :current_profit: float = unrealized profit of the current trade
        :trade_duration: int = the number of candles that the trade has
            been open for
        """
        open_trades = Trade.get_trades_proxy(is_open=True)
        market_side = 0.5
        current_profit: float = 0
        trade_duration = 0
        for trade in open_trades:
            if trade.pair == pair:
                if self.data_provider._exchange is None:  # type: ignore
                    logger.error('No exchange available.')
                    return 0, 0, 0
                else:
                    current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                        pair, refresh=False, side="exit", is_short=trade.is_short)

                now = datetime.now(timezone.utc).timestamp()
                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                current_profit = trade.calc_profit_ratio(current_rate)
                if trade.is_short:
                    market_side = 0
                else:
                    market_side = 1

        return market_side, current_profit, int(trade_duration)

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :return:
        :pred_df: dataframe containing the predictions
        :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
        data (NaNs) or felt uncertain about data (PCA and DI index)
        """

        dk.find_features(unfiltered_df)
        filtered_dataframe, _ = dk.filter_features(
            unfiltered_df, dk.training_features_list, training_filter=False
        )
        filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
        dk.data_dictionary["prediction_features"] = filtered_dataframe

        # optional additional data cleaning/analysis
        self.data_cleaning_predict(dk)

        pred_df = self.rl_model_predict(
            dk.data_dictionary["prediction_features"], dk, self.model)
        pred_df.fillna(0, inplace=True)

        return (pred_df, dk.do_predict)

    def rl_model_predict(self, dataframe: DataFrame,
                         dk: FreqaiDataKitchen, model: Any) -> DataFrame:
        """
        A helper function to make predictions in the Reinforcement learning module.
        :param dataframe: DataFrame = the dataframe of features to make the predictions on
        :param dk: FreqaiDatakitchen = data kitchen for the current pair
        :param model: Any = the trained model used for inference on the features
        """
        output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)

        def _predict(window):
            observations = dataframe.iloc[window.index]
            if self.live and self.rl_config.get('add_state_info', False):
                market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                observations['current_profit_pct'] = current_profit
                observations['position'] = market_side
                observations['trade_duration'] = trade_duration
            res, _ = model.predict(observations, deterministic=True)
            return res

        output = output.rolling(window=self.CONV_WIDTH).apply(_predict)

        return output

    def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                    pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                               DataFrame]:
        """
        Builds the train prices and test prices for the environment.
        """

        pair = pair.replace(':', '')
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        # price data for model training and evaluation
        tf = self.config['timeframe']
        ohlc_list = [f'%-{pair}raw_open_{tf}', f'%-{pair}raw_low_{tf}',
                     f'%-{pair}raw_high_{tf}', f'%-{pair}raw_close_{tf}']
        rename_dict = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
                       f'%-{pair}raw_high_{tf}': 'high', f'%-{pair}raw_close_{tf}': 'close'}

        prices_train = train_df.filter(ohlc_list, axis=1)
        if prices_train.empty:
            raise OperationalException('Reinforcement learning module did not find the raw prices '
                                       'assigned in populate_any_indicators. Please assign them '
                                       'with:\n'
                                       'informative[f"%-{pair}raw_close"] = informative["close"]\n'
                                       'informative[f"%-{pair}raw_open"] = informative["open"]\n'
                                       'informative[f"%-{pair}raw_high"] = informative["high"]\n'
                                       'informative[f"%-{pair}raw_low"] = informative["low"]\n')
        prices_train.rename(columns=rename_dict, inplace=True)
        prices_train.reset_index(drop=True)

        prices_test = test_df.filter(ohlc_list, axis=1)
        prices_test.rename(columns=rename_dict, inplace=True)
        prices_test.reset_index(drop=True)

        return prices_train, prices_test

    def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
        """
        Can be used by user if they are trying to limit_ram_usage *and*
        perform continual learning.
        For now, this is unused.
        """
        exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
        if exists:
            model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
        else:
            logger.info('No model file on disk to continue learning from.')

        return model

    def _on_stop(self):
        """
        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
        """

        if self.train_env:
            self.train_env.close()

        if self.eval_env:
            self.eval_env.close()

    # Nested class which can be overridden by user to customize further
    class MyRLEnv(Base5ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.
            :param action: int = The action made by the agent for the current candle.
            :return:
            float = the reward to give to the agent for current step (used for optimization
                of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                return -2

            pnl = self.get_unrealized_profit()
            factor = 100.

            # reward agent for entering trades
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
                    and self._position == Positions.Neutral):
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            if self._last_trade_tick:
                trade_duration = self._current_tick - self._last_trade_tick
            else:
                trade_duration = 0

            if trade_duration <= max_trade_duration:
                factor *= 1.5
            elif trade_duration > max_trade_duration:
                factor *= 0.5

            # discourage sitting in position
            if (self._position in (Positions.Short, Positions.Long) and
               action == Actions.Neutral.value):
                return -1 * trade_duration / max_trade_duration

            # close long
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            # close short
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            return 0.

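
# A minimal sketch of how a user model typically customizes the reward, following the MyRLEnv
# docstring above. "MyCustomRLModel" and its reward logic are hypothetical examples, not part of
# this file; a concrete fit() (for example the one provided by freqtrade's ReinforcementLearner
# prediction model) is still required before the class can be used.
class MyCustomRLModel(BaseReinforcementLearningModel):

    class MyRLEnv(Base5ActionRLEnv):

        def calculate_reward(self, action: int) -> float:
            # keep the invalid-action penalty, otherwise reward unrealized profit on exits
            if not self._is_valid(action):
                return -2.
            pnl = self.get_unrealized_profit()
            if (action == Actions.Long_exit.value and self._position == Positions.Long) or \
                    (action == Actions.Short_exit.value and self._position == Positions.Short):
                return float(pnl * 100.)
            return 0.
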

def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
             seed: int, train_df: DataFrame, price: DataFrame,
             monitor: bool = False,
             env_info: Dict[str, Any] = {}) -> Callable:
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID
    :param num_env: (int) the number of environments you wish to have in subprocesses
    :param seed: (int) the initial seed for RNG
    :param rank: (int) index of the subprocess
    :param env_info: (dict) all required arguments to instantiate the environment.
    :return: (Callable)
    """

    def _init() -> gym.Env:

        env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
                      **env_info)
        if monitor:
            env = Monitor(env)
        return env
    set_random_seed(seed)
    return _init