
freqtrade / freqtrade / build 4131167254 (pending completion)
Triggered by: push (github-actions, GitHub)
Merge pull request #7983 from stash86/bt-metrics

16866 of 17748 relevant lines covered (95.03%)
0.95 hits per line

Source file: /freqtrade/freqai/RL/BaseEnvironment.py (86.75% covered)

import logging
import random
from abc import abstractmethod
from enum import Enum
from typing import Optional, Type, Union

import gym
import numpy as np
import pandas as pd
from gym import spaces
from gym.utils import seeding
from pandas import DataFrame


logger = logging.getLogger(__name__)


class BaseActions(Enum):
    """
    Default action space, mostly used for type handling.
    """
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4


class Positions(Enum):
    Short = 0
    Long = 1
    Neutral = 0.5

    def opposite(self):
        return Positions.Short if self == Positions.Long else Positions.Long


class BaseEnvironment(gym.Env):
    """
    Base class for environments. This class is agnostic to action count.
    Inherited classes customize this to include varying action counts/types,
    See RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py
    """

    def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                 reward_kwargs: dict = {}, window_size=10, starting_point=True,
                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
                 fee: float = 0.0015, can_short: bool = False):
        """
        Initializes the training/eval environment.
        :param df: dataframe of features
        :param prices: dataframe of prices to be used in the training environment
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`
        :param starting_point: start at edge of window or not
        :param id: string id of the environment (used in backend for multiprocessed env)
        :param seed: Sets the seed of the environment higher in the gym.Env object
        :param config: Typical user configuration file
        :param live: Whether or not this environment is active in dry/live/backtesting
        :param fee: The fee to use for environmental interactions.
        :param can_short: Whether or not the environment can short
        """
        self.config = config
        self.rl_config = config['freqai']['rl_config']
        self.add_state_info = self.rl_config.get('add_state_info', False)
        self.id = id
        self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
        self.compound_trades = config['stake_amount'] == 'unlimited'
        if self.config.get('fee', None) is not None:
            self.fee = self.config['fee']
        else:
            self.fee = fee

        # set here to default 5Ac, but all children envs can override this
        self.actions: Type[Enum] = BaseActions
        self.tensorboard_metrics: dict = {}
        self.can_short = can_short
        self.live = live
        if not self.live and self.add_state_info:
            self.add_state_info = False
            logger.warning("add_state_info is not available in backtesting. Deactivating.")
        self.seed(seed)
        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)

    def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                  reward_kwargs: dict, starting_point=True):
        """
        Resets the environment when the agent fails (in our case, if the drawdown
        exceeds the user set max_training_drawdown_pct)
        :param df: dataframe of features
        :param prices: dataframe of prices to be used in the training environment
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`
        :param starting_point: start at edge of window or not
        """
        self.df = df
        self.signal_features = self.df
        self.prices = prices
        self.window_size = window_size
        self.starting_point = starting_point
        self.rr = reward_kwargs["rr"]
        self.profit_aim = reward_kwargs["profit_aim"]

        # spaces
        if self.add_state_info:
            self.total_features = self.signal_features.shape[1] + 3
        else:
            self.total_features = self.signal_features.shape[1]
        self.shape = (window_size, self.total_features)
        self.set_action_space()
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick: int = self.window_size
        self._end_tick: int = len(self.prices) - 1
        self._done: bool = False
        self._current_tick: int = self._start_tick
        self._last_trade_tick: Optional[int] = None
        self._position = Positions.Neutral
        self._position_history: list = [None]
        self.total_reward: float = 0
        self._total_profit: float = 1
        self._total_unrealized_profit: float = 1
        self.history: dict = {}
        self.trade_history: list = []

    @abstractmethod
    def set_action_space(self):
        """
        Unique to the environment action count. Must be inherited.
        """
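    # A minimal sketch (illustrative only, not part of the original module) of how
    # a child environment might satisfy this abstract method, assuming it reuses
    # the default BaseActions enum and a discrete gym action space:
    #
    #     class MyRLEnv(BaseEnvironment):
    #         def set_action_space(self):
    #             self.actions = BaseActions
    #             self.action_space = spaces.Discrete(len(self.actions))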

    def seed(self, seed: int = 1):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
        """
        Function builds the tensorboard_metrics dictionary
        to be parsed by the TensorboardCallback. This
        function is designed for tracking incremented objects,
        events, actions inside the training environment.
        For example, a user can call this to track the
        frequency of occurrence of an `is_valid` call in
        their `calculate_reward()`:

        def calculate_reward(self, action: int) -> float:
            if not self._is_valid(action):
                self.tensorboard_log("is_valid")
                return -2

        :param metric: metric to be tracked and incremented
        :param value: value to increment `metric` by
        :param inc: sets whether the `value` is incremented or not
        """
        if not inc or metric not in self.tensorboard_metrics:
            self.tensorboard_metrics[metric] = value
        else:
            self.tensorboard_metrics[metric] += value
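    # Illustrative usage (not part of the original module): besides counting
    # events, a raw value can be recorded each step by passing inc=False, e.g.
    #
    #     self.tensorboard_log("total_profit", self._total_profit, inc=False)
    #
    # which overwrites the stored metric instead of incrementing it.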

    def reset_tensorboard_log(self):
        self.tensorboard_metrics = {}

    def reset(self):
        """
        Reset is called at the beginning of every episode
        """
        self.reset_tensorboard_log()

        self._done = False

        if self.starting_point is True:
            if self.rl_config.get('randomize_starting_position', False):
                length_of_data = int(self._end_tick / 4)
                start_tick = random.randint(self.window_size + 1, length_of_data)
                self._start_tick = start_tick
            self._position_history = (self._start_tick * [None]) + [self._position]
        else:
            self._position_history = (self.window_size * [None]) + [self._position]

        self._current_tick = self._start_tick
        self._last_trade_tick = None
        self._position = Positions.Neutral

        self.total_reward = 0.
        self._total_profit = 1.  # unit
        self.history = {}
        self.trade_history = []
        self.portfolio_log_returns = np.zeros(len(self.prices))

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
        self._total_unrealized_profit = 1

        return self._get_observation()

    @abstractmethod
    def step(self, action: int):
        """
        Step depends on the action types; this must be inherited.
        """
        return

    def _get_observation(self):
        """
        This may or may not be independent of action types; the user can inherit
        this in their custom "MyRLEnv"
        """
        features_window = self.signal_features[(
            self._current_tick - self.window_size):self._current_tick]
        if self.add_state_info:
            features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                           columns=['current_profit_pct',
                                                    'position',
                                                    'trade_duration'],
                                           index=features_window.index)

            features_and_state['current_profit_pct'] = self.get_unrealized_profit()
            features_and_state['position'] = self._position.value
            features_and_state['trade_duration'] = self.get_trade_duration()
            features_and_state = pd.concat([features_window, features_and_state], axis=1)
            return features_and_state
        else:
            return features_window

    def get_trade_duration(self):
        """
        Get the trade duration if the agent is in a trade
        """
        if self._last_trade_tick is None:
            return 0
        else:
            return self._current_tick - self._last_trade_tick

    def get_unrealized_profit(self):
        """
        Get the unrealized profit if the agent is in a trade
        """
        if self._last_trade_tick is None:
            return 0.

        if self._position == Positions.Neutral:
            return 0.
        elif self._position == Positions.Short:
            current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
            return (last_trade_price - current_price) / last_trade_price
        elif self._position == Positions.Long:
            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
            return (current_price - last_trade_price) / last_trade_price
        else:
            return 0.
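    # Worked example (illustrative, not part of the original module): with the
    # default fee of 0.0015, a long entered at an open of 100 and marked at a
    # current open of 110 gives
    #     last_trade_price = add_entry_fee(100) = 100.15
    #     current_price    = add_exit_fee(110)  ≈ 109.84
    #     unrealized profit ≈ (109.84 - 100.15) / 100.15 ≈ 0.0967
    # i.e. slightly less than the raw 10% move, since both fees are charged.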

    @abstractmethod
    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def add_entry_fee(self, price):
        return price * (1 + self.fee)

    def add_exit_fee(self, price):
        return price / (1 + self.fee)

    def _update_history(self, info):
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    @abstractmethod
    def calculate_reward(self, action: int) -> float:
        """
        An example reward function. This is the one function that users will likely
        wish to inject their own creativity into.
        :param action: int = The action made by the agent for the current candle.
        :return:
        float = the reward to give to the agent for current step (used for optimization
            of weights in NN)
        """
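    # A minimal sketch (illustrative only, not freqtrade's shipped reward) of a
    # custom calculate_reward() in a user's "MyRLEnv", using only attributes
    # defined in this base class:
    #
    #     def calculate_reward(self, action: int) -> float:
    #         if not self._is_valid(action):
    #             self.tensorboard_log("invalid_action")
    #             return -2.
    #         pnl = self.get_unrealized_profit()
    #         # reward closing a trade by its fee-adjusted profit, scaled by rr,
    #         # and mildly penalise sitting idle in a neutral position
    #         if action in (BaseActions.Long_exit.value, BaseActions.Short_exit.value):
    #             return float(pnl * self.rr)
    #         if action == BaseActions.Neutral.value and self._position == Positions.Neutral:
    #             return -0.01
    #         return 0.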

    def _update_unrealized_total_profit(self):
        """
        Update the unrealized total profit in case of episode end.
        """
        if self._position in (Positions.Long, Positions.Short):
            pnl = self.get_unrealized_profit()
            if self.compound_trades:
                # assumes unit stake and compounding
                unrl_profit = self._total_profit * (1 + pnl)
            else:
                # assumes unit stake and no compounding
                unrl_profit = self._total_profit + pnl
            self._total_unrealized_profit = unrl_profit

    def _update_total_profit(self):
        pnl = self.get_unrealized_profit()
        if self.compound_trades:
            # assumes unit stake and compounding
            self._total_profit = self._total_profit * (1 + pnl)
        else:
            # assumes unit stake and no compounding
            self._total_profit += pnl
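    # Worked example (illustrative): if _total_profit is 1.20 and a trade closes
    # with pnl = 0.05, the compounding branch gives 1.20 * 1.05 = 1.26, while the
    # non-compounding (unit-stake) branch gives 1.20 + 0.05 = 1.25.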

    def current_price(self) -> float:
        return self.prices.iloc[self._current_tick].open

    def get_actions(self) -> Type[Enum]:
        """
        Used by SubprocVecEnv to get actions from
        initialized env for tensorboard callback
        """
        return self.actions

    # Keeping around in case we want to start building more complex environment
    # templates in the future.
    # def most_recent_return(self):
    #     """
    #     Calculate the tick to tick return if in a trade.
    #     Return is generated from rising prices in Long
    #     and falling prices in Short positions.
    #     The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
    #     """
    #     # Long positions
    #     if self._position == Positions.Long:
    #         current_price = self.prices.iloc[self._current_tick].open
    #         previous_price = self.prices.iloc[self._current_tick - 1].open

    #         if (self._position_history[self._current_tick - 1] == Positions.Short
    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
    #             previous_price = self.add_entry_fee(previous_price)

    #         return np.log(current_price) - np.log(previous_price)

    #     # Short positions
    #     if self._position == Positions.Short:
    #         current_price = self.prices.iloc[self._current_tick].open
    #         previous_price = self.prices.iloc[self._current_tick - 1].open
    #         if (self._position_history[self._current_tick - 1] == Positions.Long
    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
    #             previous_price = self.add_exit_fee(previous_price)

    #         return np.log(previous_price) - np.log(current_price)

    #     return 0

    # def update_portfolio_log_returns(self, action):
    #     self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)