freqtrade / freqtrade / 9394559170 (push via github, 26 Apr 2024 06:36AM UTC)
xmatthias: Loader should be passed as kwarg for clarity
Coverage: 94.656% (-0.02% from 94.674%), 20280 of 21425 relevant lines covered, 0.95 hits per line

Source file (74.21% covered):
/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
import copy
import importlib
import logging
from abc import abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import numpy.typing as npt
import pandas as pd
import torch as th
import torch.multiprocessing
from pandas import DataFrame
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.utils import is_masking_supported
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
from freqtrade.persistence import Trade


logger = logging.getLogger(__name__)

torch.multiprocessing.set_sharing_strategy('file_system')

SB3_MODELS = ['PPO', 'A2C', 'DQN']
SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'QRDQN']
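
# A minimal sketch of the 'rl_config' block that this module reads from the FreqAI
# configuration. The keys mirror the lookups performed below (model_type, policy_type,
# cpu_count, net_arch, model_reward_parameters, drop_ohlc_from_features,
# max_trade_duration_candles, add_state_info); the values are illustrative assumptions only:
#
#   "freqai": {
#       "rl_config": {
#           "model_type": "PPO",
#           "policy_type": "MlpPolicy",
#           "cpu_count": 2,
#           "net_arch": [128, 128],
#           "drop_ohlc_from_features": false,
#           "max_trade_duration_candles": 300,
#           "add_state_info": false,
#           "model_reward_parameters": {"rr": 1, "profit_aim": 0.025, "win_reward_factor": 2}
#       }
#   }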


class BaseReinforcementLearningModel(IFreqaiModel):
    """
    User created Reinforcement Learning Model prediction class
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(config=kwargs['config'])
        self.max_threads = min(self.freqai_info['rl_config'].get(
            'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
        th.set_num_threads(self.max_threads)
        self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
        self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
        self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
        self.eval_callback: Optional[MaskableEvalCallback] = None
        self.model_type = self.freqai_info['rl_config']['model_type']
        self.rl_config = self.freqai_info['rl_config']
        self.df_raw: DataFrame = DataFrame()
        self.continual_learning = self.freqai_info.get('continual_learning', False)
        if self.model_type in SB3_MODELS:
            import_str = 'stable_baselines3'
        elif self.model_type in SB3_CONTRIB_MODELS:
            import_str = 'sb3_contrib'
        else:
            raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
                                       f'sb3_contrib. please choose one of {SB3_MODELS} or '
                                       f'{SB3_CONTRIB_MODELS}')

        mod = importlib.import_module(import_str, self.model_type)
        self.MODELCLASS = getattr(mod, self.model_type)
        self.policy_type = self.freqai_info['rl_config']['policy_type']
        self.unset_outlier_removal()
        self.net_arch = self.rl_config.get('net_arch', [128, 128])
        self.dd.model_type = import_str
        self.tensorboard_callback: TensorboardCallback = \
            TensorboardCallback(verbose=1, actions=BaseActions)

    def unset_outlier_removal(self):
        """
        If the user has activated any function that may remove training points, this
        function disables it and warns the user.
        """
        if self.ft_params.get('use_SVM_to_remove_outliers', False):
            self.ft_params.update({'use_SVM_to_remove_outliers': False})
            logger.warning('User tried to use SVM with RL. Deactivating SVM.')
        if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
            logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
        if self.ft_params.get('DI_threshold', False):
            self.ft_params.update({'DI_threshold': False})
            logger.warning('User tried to use DI_threshold with RL. Deactivating DI_threshold.')
        if self.freqai_info['data_split_parameters'].get('shuffle', False):
            self.freqai_info['data_split_parameters'].update({'shuffle': False})
            logger.warning('User tried to shuffle training data. Setting shuffle to False')

    def train(
        self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
    ) -> Any:
        """
        Filter the training data and train a model to it. Train makes heavy use of the datakitchen
        for storing, saving, loading, and analyzing the data.
        :param unfiltered_df: Full dataframe for the current training period
        :param pair: pair to train the model on
        :param dk: FreqaiDataKitchen = the datakitchen for the current pair
        :returns:
        :model: Trained model which can be used for inference (self.predict)
        """

        logger.info("--------------------Starting training " f"{pair} --------------------")

        features_filtered, labels_filtered = dk.filter_features(
            unfiltered_df,
            dk.training_features_list,
            dk.label_list,
            training_filter=True,
        )

        dd: Dict[str, Any] = dk.make_train_test_datasets(
            features_filtered, labels_filtered)
        self.df_raw = copy.deepcopy(dd["train_features"])
        dk.fit_labels()  # FIXME useless for now, but just satiating append methods

        # normalize all data based on train_dataset only
        prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)

        dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)

        (dd["train_features"],
         dd["train_labels"],
         dd["train_weights"]) = dk.feature_pipeline.fit_transform(dd["train_features"],
                                                                  dd["train_labels"],
                                                                  dd["train_weights"])

        if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) != 0:
            (dd["test_features"],
             dd["test_labels"],
             dd["test_weights"]) = dk.feature_pipeline.transform(dd["test_features"],
                                                                 dd["test_labels"],
                                                                 dd["test_weights"])

        logger.info(
            f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
            f' features and {len(dd["train_features"])} data points'
        )

        self.set_train_and_eval_environments(dd, prices_train, prices_test, dk)

        model = self.fit(dd, dk)

        logger.info(f"--------------------done training {pair}--------------------")

        return model

    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
                                        prices_train: DataFrame, prices_test: DataFrame,
                                        dk: FreqaiDataKitchen):
        """
        User can override this if they are using a custom MyRLEnv
        :param data_dictionary: dict = common data dictionary containing train and test
            features/labels/weights.
        :param prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
            environment during training or testing
        :param dk: FreqaiDataKitchen = the datakitchen for the current pair
        """
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        env_info = self.pack_env_dict(dk.pair)

        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))
        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
                                                  render=False, eval_freq=len(train_df),
                                                  best_model_save_path=str(dk.data_path),
                                                  use_masking=(self.model_type == 'MaskablePPO' and
                                                               is_masking_supported(self.eval_env)))

        actions = self.train_env.get_actions()
        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

    def pack_env_dict(self, pair: str) -> Dict[str, Any]:
        """
        Create dictionary of environment arguments
        """
        env_info = {"window_size": self.CONV_WIDTH,
                    "reward_kwargs": self.reward_params,
                    "config": self.config,
                    "live": self.live,
                    "can_short": self.can_short,
                    "pair": pair,
                    "df_raw": self.df_raw}
        if self.data_provider:
            env_info["fee"] = self.data_provider._exchange \
                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore

        return env_info

    @abstractmethod
    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        Agent customizations and abstract Reinforcement Learning customizations
        go in here. Abstract method, so this function must be overridden by
        user class.
        """
        return
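
    # A minimal sketch of a fit() override (illustrative only; the policy kwargs and
    # timestep count below are assumptions, not values taken from this module):
    #
    #     def fit(self, data_dictionary, dk, **kwargs):
    #         model = self.MODELCLASS(self.policy_type, self.train_env,
    #                                 policy_kwargs=dict(net_arch=self.net_arch),
    #                                 verbose=1)
    #         model.learn(total_timesteps=10_000,
    #                     callback=[self.eval_callback, self.tensorboard_callback])
    #         return model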

    def get_state_info(self, pair: str) -> Tuple[float, float, int]:
        """
        State info during dry/live (not backtesting) which is fed back
        into the model.
        :param pair: str = COIN/STAKE to get the environment information for
        :return:
        :market_side: float = the position held for the pair
            (1.0 = long, 0.0 = short, 0.5 = neutral)
        :current_profit: float = unrealized profit of the current trade
        :trade_duration: int = the number of candles that the trade has
            been open for
        """
        open_trades = Trade.get_trades_proxy(is_open=True)
        market_side = 0.5
        current_profit: float = 0
        trade_duration = 0
        for trade in open_trades:
            if trade.pair == pair:
                if self.data_provider._exchange is None:  # type: ignore
                    logger.error('No exchange available.')
                    return 0, 0, 0
                else:
                    current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                                pair, refresh=False, side="exit", is_short=trade.is_short)

                now = datetime.now(timezone.utc).timestamp()
                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                current_profit = trade.calc_profit_ratio(current_rate)
                if trade.is_short:
                    market_side = 0
                else:
                    market_side = 1

        return market_side, current_profit, int(trade_duration)

    def predict(
        self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
    ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
        """
        Filter the prediction features data and predict with it.
        :param unfiltered_df: Full dataframe for the current backtest period.
        :return:
        :pred_df: dataframe containing the predictions
        :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
        data (NaNs) or felt uncertain about data (PCA and DI index)
        """

        dk.find_features(unfiltered_df)
        filtered_dataframe, _ = dk.filter_features(
            unfiltered_df, dk.training_features_list, training_filter=False
        )

        dk.data_dictionary["prediction_features"] = self.drop_ohlc_from_df(filtered_dataframe, dk)

        dk.data_dictionary["prediction_features"], _, _ = dk.feature_pipeline.transform(
            dk.data_dictionary["prediction_features"], outlier_check=True)

        pred_df = self.rl_model_predict(
            dk.data_dictionary["prediction_features"], dk, self.model)
        pred_df.fillna(0, inplace=True)

        return (pred_df, dk.do_predict)

    def rl_model_predict(self, dataframe: DataFrame,
                         dk: FreqaiDataKitchen, model: Any) -> DataFrame:
        """
        A helper function to make predictions in the Reinforcement learning module.
        :param dataframe: DataFrame = the dataframe of features to make the predictions on
        :param dk: FreqaiDatakitchen = data kitchen for the current pair
        :param model: Any = the trained model used for inference on the features
        """
        output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)

        def _predict(window):
            observations = dataframe.iloc[window.index]
            if self.live and self.rl_config.get('add_state_info', False):
                market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                observations['current_profit_pct'] = current_profit
                observations['position'] = market_side
                observations['trade_duration'] = trade_duration
            res, _ = model.predict(observations, deterministic=True)
            return res

        output = output.rolling(window=self.CONV_WIDTH).apply(_predict)

        return output

    def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                    pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                               DataFrame]:
        """
        Builds the train prices and test prices for the environment.
        """

        pair = pair.replace(':', '')
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        # price data for model training and evaluation
        tf = self.config['timeframe']
        rename_dict = {'%-raw_open': 'open', '%-raw_low': 'low',
                       '%-raw_high': 'high', '%-raw_close': 'close'}
        rename_dict_old = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
                           f'%-{pair}raw_high_{tf}': 'high', f'%-{pair}raw_close_{tf}': 'close'}

        prices_train = train_df.filter(rename_dict.keys(), axis=1)
        prices_train_old = train_df.filter(rename_dict_old.keys(), axis=1)
        if prices_train.empty or not prices_train_old.empty:
            if not prices_train_old.empty:
                prices_train = prices_train_old
                rename_dict = rename_dict_old
            logger.warning('Reinforcement learning module didn\'t find the correct raw prices '
                           'assigned in feature_engineering_standard(). '
                           'Please assign them with:\n'
                           'dataframe["%-raw_close"] = dataframe["close"]\n'
                           'dataframe["%-raw_open"] = dataframe["open"]\n'
                           'dataframe["%-raw_high"] = dataframe["high"]\n'
                           'dataframe["%-raw_low"] = dataframe["low"]\n'
                           'inside `feature_engineering_standard()`')
        if prices_train.empty:
            raise OperationalException("No prices found, please follow log warning "
                                       "instructions to correct the strategy.")

        prices_train.rename(columns=rename_dict, inplace=True)
        prices_train.reset_index(drop=True, inplace=True)

        prices_test = test_df.filter(rename_dict.keys(), axis=1)
        prices_test.rename(columns=rename_dict, inplace=True)
        prices_test.reset_index(drop=True, inplace=True)

        train_df = self.drop_ohlc_from_df(train_df, dk)
        test_df = self.drop_ohlc_from_df(test_df, dk)

        return prices_train, prices_test

    def drop_ohlc_from_df(self, df: DataFrame, dk: FreqaiDataKitchen):
        """
        Given a dataframe, drop the ohlc data
        """
        drop_list = ['%-raw_open', '%-raw_low', '%-raw_high', '%-raw_close']

        if self.rl_config["drop_ohlc_from_features"]:
            df.drop(drop_list, axis=1, inplace=True)
            feature_list = dk.training_features_list
            dk.training_features_list = [e for e in feature_list if e not in drop_list]

        return df

    def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
        """
        Can be used by user if they are trying to limit_ram_usage *and*
        perform continual learning.
        For now, this is unused.
        """
        exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
        if exists:
            model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
        else:
            logger.info('No model file on disk to continue learning from.')
            model = None  # avoid returning an unbound name when no model file exists

        return model

    def _on_stop(self):
        """
        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
        """

        if self.train_env:
            self.train_env.close()

        if self.eval_env:
            self.eval_env.close()

    # Nested class which can be overridden by user to customize further
    class MyRLEnv(Base5ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:  # noqa: C901
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.

            Warning!
            This function is a showcase of functionality designed to show as many
            environment control features as possible. It is also designed to run quickly
            on small computers. This is a benchmark; it is *not* for live production.

            :param action: int = The action made by the agent for the current candle.
            :return:
            float = the reward to give to the agent for current step (used for optimization
                of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                return -2

            pnl = self.get_unrealized_profit()
            factor = 100.

            # you can use feature values from dataframe
            rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
                                        f"{self.config['timeframe']}"].iloc[self._current_tick]

            # reward agent for entering trades
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
                    and self._position == Positions.Neutral):
                if rsi_now < 40:
                    factor = 40 / rsi_now
                else:
                    factor = 1
                return 25 * factor

            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            if self._last_trade_tick:
                trade_duration = self._current_tick - self._last_trade_tick
            else:
                trade_duration = 0

            if trade_duration <= max_trade_duration:
                factor *= 1.5
            elif trade_duration > max_trade_duration:
                factor *= 0.5

            # discourage sitting in position
            if (self._position in (Positions.Short, Positions.Long) and
               action == Actions.Neutral.value):
                return -1 * trade_duration / max_trade_duration

            # close long
            if action == Actions.Long_exit.value and self._position == Positions.Long:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            # close short
            if action == Actions.Short_exit.value and self._position == Positions.Short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                return float(pnl * factor)

            return 0.
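
        # Worked example of the reward above (assumed numbers, for illustration only):
        # closing a long with pnl = 0.03, profit_aim = 0.025, rr = 1, win_reward_factor = 2
        # and trade_duration <= max_trade_duration gives factor = 100 * 1.5 * 2 = 300,
        # so the reward is 0.03 * 300 = 9.0.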


def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
             seed: int, train_df: DataFrame, price: DataFrame,
             env_info: Dict[str, Any] = {}) -> Callable:
    """
    Utility function for multiprocessed env.

    :param MyRLEnv: (Type[BaseEnvironment]) the environment class to instantiate
    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :param train_df: (DataFrame) dataframe of training features
    :param price: (DataFrame) dataframe of prices for the environment
    :param env_info: (dict) all required arguments to instantiate the environment.
    :return: (Callable)
    """

    def _init() -> gym.Env:

        env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
                      **env_info)

        return env
    set_random_seed(seed)
    return _init
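

# A minimal usage sketch for make_env() (not part of this module): subclasses that want
# multiprocessed environments typically override set_train_and_eval_environments() and
# build vectorized envs roughly like the following. The env_id string and seed value
# are illustrative assumptions:
#
#     env_info = self.pack_env_dict(dk.pair)
#     self.train_env = VecMonitor(SubprocVecEnv(
#         [make_env(self.MyRLEnv, "train_env", i, 42, train_df, prices_train,
#                   env_info=env_info) for i in range(self.max_threads)]))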