freqtrade / freqtrade / 6181253459

08 Sep 2023 06:04AM UTC coverage: 94.614% (+0.06%) from 94.556%

Build 6181253459 (push, github-actions, web-flow)
Merge pull request #9159 from stash86/fix-adjust
remove old codes when we only can do partial entries

2 of 2 new or added lines in 1 file covered. (100.0%)
19114 of 20202 relevant lines covered (94.61%)
0.95 hits per line

Source File

/freqtrade/freqai/RL/BaseEnvironment.py (86.86% covered)
import logging
import random
from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Type, Union

import gymnasium as gym
import numpy as np
import pandas as pd
from gymnasium import spaces
from gymnasium.utils import seeding
from pandas import DataFrame

from freqtrade.exceptions import OperationalException


logger = logging.getLogger(__name__)

class BaseActions(Enum):
    """
    Default action space, mostly used for type handling.
    """
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4


class Positions(Enum):
    Short = 0
    Long = 1
    Neutral = 0.5

    def opposite(self):
        return Positions.Short if self == Positions.Long else Positions.Long

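# Illustrative note (not part of the original file): Positions.Long.opposite()
# returns Positions.Short, while opposite() on any other member, including
# Neutral, returns Positions.Long, because the method only special-cases Long.
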
class BaseEnvironment(gym.Env):
    """
    Base class for environments. This class is agnostic to action count.
    Inherited classes customize this to include varying action counts/types;
    see RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py
    """

    def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                 reward_kwargs: dict = {}, window_size=10, starting_point=True,
                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
                 fee: float = 0.0015, can_short: bool = False, pair: str = "",
                 df_raw: DataFrame = DataFrame()):
        """
        Initializes the training/eval environment.
        :param df: dataframe of features
        :param prices: dataframe of prices to be used in the training environment
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`
        :param starting_point: start at edge of window or not
        :param id: string id of the environment (used in backend for multiprocessed env)
        :param seed: Sets the seed of the environment higher in the gym.Env object
        :param config: Typical user configuration file
        :param live: Whether or not this environment is active in dry/live/backtesting
        :param fee: The fee to use for environmental interactions.
        :param can_short: Whether or not the environment can short
        """
        self.config: dict = config
        self.rl_config: dict = config['freqai']['rl_config']
        self.add_state_info: bool = self.rl_config.get('add_state_info', False)
        self.id: str = id
        self.max_drawdown: float = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
        self.compound_trades: bool = config['stake_amount'] == 'unlimited'
        self.pair: str = pair
        self.raw_features: DataFrame = df_raw
        if self.config.get('fee', None) is not None:
            self.fee = self.config['fee']
        else:
            self.fee = fee

        # set here to the default 5-action space, but all children envs can override this
        self.actions: Type[Enum] = BaseActions
        self.tensorboard_metrics: dict = {}
        self.can_short: bool = can_short
        self.live: bool = live
        if not self.live and self.add_state_info:
            raise OperationalException("`add_state_info` is not available in backtesting. Change "
                                       "parameter to false in your rl_config. See `add_state_info` "
                                       "docs for more info.")
        self.seed(seed)
        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)

    def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                  reward_kwargs: dict, starting_point=True):
        """
        Resets the environment when the agent fails (in our case, if the drawdown
        exceeds the user-set max_training_drawdown_pct)
        :param df: dataframe of features
        :param prices: dataframe of prices to be used in the training environment
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`
        :param starting_point: start at edge of window or not
        """
        self.signal_features: DataFrame = df
        self.prices: DataFrame = prices
        self.window_size: int = window_size
        self.starting_point: bool = starting_point
        self.rr: float = reward_kwargs["rr"]
        self.profit_aim: float = reward_kwargs["profit_aim"]

        # spaces
        if self.add_state_info:
            self.total_features = self.signal_features.shape[1] + 3
        else:
            self.total_features = self.signal_features.shape[1]
        self.shape = (window_size, self.total_features)
        self.set_action_space()
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick: int = self.window_size
        self._end_tick: int = len(self.prices) - 1
        self._done: bool = False
        self._current_tick: int = self._start_tick
        self._last_trade_tick: Optional[int] = None
        self._position = Positions.Neutral
        self._position_history: list = [None]
        self.total_reward: float = 0
        self._total_profit: float = 1
        self._total_unrealized_profit: float = 1
        self.history: dict = {}
        self.trade_history: list = []

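    # Illustrative shape arithmetic (not part of the original file): with
    # window_size=10 and a signal_features dataframe of 100 columns, the
    # observation_space shape is (10, 100); with add_state_info=True the three
    # extra state columns raise it to (10, 103).
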
    def get_attr(self, attr: str):
        """
        Returns the attribute of the environment
        :param attr: attribute to return
        :return: attribute
        """
        return getattr(self, attr)

    @abstractmethod
    def set_action_space(self):
        """
        Unique to the environment action count. Must be overridden in the inheriting environment.
        """

    def action_masks(self) -> List[bool]:
        return [self._is_valid(action.value) for action in self.actions]

    def seed(self, seed: int = 1):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def tensorboard_log(self, metric: str, value: Optional[Union[int, float]] = None,
                        inc: Optional[bool] = None, category: str = "custom"):
        """
        Function builds the tensorboard_metrics dictionary
        to be parsed by the TensorboardCallback. This
        function is designed for tracking incremented objects,
        events, actions inside the training environment.
        For example, a user can call this to track the
        frequency of occurrence of an `is_valid` call in
        their `calculate_reward()`:

        def calculate_reward(self, action: int) -> float:
            if not self._is_valid(action):
                self.tensorboard_log("invalid")
                return -2

        :param metric: metric to be tracked and incremented
        :param value: `metric` value
        :param inc: (deprecated) sets whether the `value` is incremented or not
        :param category: `metric` category
        """
        increment = True if value is None else False
        value = 1 if increment else value

        if category not in self.tensorboard_metrics:
            self.tensorboard_metrics[category] = {}

        if not increment or metric not in self.tensorboard_metrics[category]:
            self.tensorboard_metrics[category][metric] = value
        else:
            self.tensorboard_metrics[category][metric] += value

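    # Usage sketch (illustrative, not part of the original file): with no value
    # the metric acts as a counter and is incremented by 1 on each call; with a
    # value it is simply overwritten.
    #     self.tensorboard_log("invalid")            # counter, +1 per call
    #     self.tensorboard_log("pnl", value=0.05)    # stored as given
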
    def reset_tensorboard_log(self):
        self.tensorboard_metrics = {}

    def reset(self, seed=None):
        """
        Reset is called at the beginning of every episode
        """
        self.reset_tensorboard_log()

        self._done = False

        if self.starting_point is True:
            if self.rl_config.get('randomize_starting_position', False):
                length_of_data = int(self._end_tick / 4)
                start_tick = random.randint(self.window_size + 1, length_of_data)
                self._start_tick = start_tick
            self._position_history = (self._start_tick * [None]) + [self._position]
        else:
            self._position_history = (self.window_size * [None]) + [self._position]

        self._current_tick = self._start_tick
        self._last_trade_tick = None
        self._position = Positions.Neutral

        self.total_reward = 0.
        self._total_profit = 1.  # unit
        self.history = {}
        self.trade_history = []
        self.portfolio_log_returns = np.zeros(len(self.prices))

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
        self._total_unrealized_profit = 1

        return self._get_observation(), self.history

    @abstractmethod
    def step(self, action: int):
        """
        Step depends on the action types; this must be inherited.
        """
        return

    def _get_observation(self):
        """
        This may or may not be independent of action types; the user can
        override this in their custom "MyRLEnv"
        """
        features_window = self.signal_features[(
            self._current_tick - self.window_size):self._current_tick]
        if self.add_state_info:
            features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                           columns=['current_profit_pct',
                                                    'position',
                                                    'trade_duration'],
                                           index=features_window.index)

            features_and_state['current_profit_pct'] = self.get_unrealized_profit()
            features_and_state['position'] = self._position.value
            features_and_state['trade_duration'] = self.get_trade_duration()
            features_and_state = pd.concat([features_window, features_and_state], axis=1)
            return features_and_state
        else:
            return features_window

    def get_trade_duration(self):
        """
        Get the trade duration if the agent is in a trade
        """
        if self._last_trade_tick is None:
            return 0
        else:
            return self._current_tick - self._last_trade_tick

    def get_unrealized_profit(self):
        """
        Get the unrealized profit if the agent is in a trade
        """
        if self._last_trade_tick is None:
            return 0.

        if self._position == Positions.Neutral:
            return 0.
        elif self._position == Positions.Short:
            current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
            return (last_trade_price - current_price) / last_trade_price
        elif self._position == Positions.Long:
            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
            return (current_price - last_trade_price) / last_trade_price
        else:
            return 0.

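    # Worked example (illustrative, not part of the original file): with the
    # default fee of 0.0015, a long entered at an open of 100.0 is booked at
    # add_entry_fee(100.0) = 100.15; if the current open is 105.0, the exit side
    # is add_exit_fee(105.0) = 105.0 / 1.0015 ~ 104.84, and the unrealized
    # profit is (104.84 - 100.15) / 100.15 ~ 0.047, i.e. roughly 4.7%.
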
    @abstractmethod
    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def add_entry_fee(self, price):
        return price * (1 + self.fee)

    def add_exit_fee(self, price):
        return price / (1 + self.fee)

    def _update_history(self, info):
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    @abstractmethod
    def calculate_reward(self, action: int) -> float:
        """
        An example reward function. This is the one function that users will likely
        wish to inject their own creativity into.

        Warning!
        This function is a showcase of functionality designed to show as many
        environment control features as possible. It is also designed to run quickly
        on small computers. This is a benchmark; it is *not* for live production.

        :param action: int = The action made by the agent for the current candle.
        :return:
        float = the reward to give to the agent for current step (used for optimization
            of weights in NN)
        """

    def _update_unrealized_total_profit(self):
        """
        Update the unrealized total profit in case of episode end.
        """
        if self._position in (Positions.Long, Positions.Short):
            pnl = self.get_unrealized_profit()
            if self.compound_trades:
                # assumes unit stake and compounding
                unrl_profit = self._total_profit * (1 + pnl)
            else:
                # assumes unit stake and no compounding
                unrl_profit = self._total_profit + pnl
            self._total_unrealized_profit = unrl_profit

    def _update_total_profit(self):
        pnl = self.get_unrealized_profit()
        if self.compound_trades:
            # assumes unit stake and compounding
            self._total_profit = self._total_profit * (1 + pnl)
        else:
            # assumes unit stake and no compounding
            self._total_profit += pnl

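    # Illustrative arithmetic (not part of the original file): with
    # _total_profit = 1.10 and a closed-trade pnl of 0.05, the compounding
    # branch ('stake_amount' == 'unlimited') gives 1.10 * (1 + 0.05) = 1.155,
    # while the non-compounding branch gives 1.10 + 0.05 = 1.15.
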
    def current_price(self) -> float:
        return self.prices.iloc[self._current_tick].open

    def get_actions(self) -> Type[Enum]:
        """
        Used by SubprocVecEnv to get actions from
        initialized env for tensorboard callback
        """
        return self.actions

    # Keeping around in case we want to start building more complex environment
    # templates in the future.
    # def most_recent_return(self):
    #     """
    #     Calculate the tick to tick return if in a trade.
    #     Return is generated from rising prices in Long
    #     and falling prices in Short positions.
    #     The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
    #     """
    #     # Long positions
    #     if self._position == Positions.Long:
    #         current_price = self.prices.iloc[self._current_tick].open
    #         previous_price = self.prices.iloc[self._current_tick - 1].open

    #         if (self._position_history[self._current_tick - 1] == Positions.Short
    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
    #             previous_price = self.add_entry_fee(previous_price)

    #         return np.log(current_price) - np.log(previous_price)

    #     # Short positions
    #     if self._position == Positions.Short:
    #         current_price = self.prices.iloc[self._current_tick].open
    #         previous_price = self.prices.iloc[self._current_tick - 1].open
    #         if (self._position_history[self._current_tick - 1] == Positions.Long
    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
    #             previous_price = self.add_exit_fee(previous_price)

    #         return np.log(previous_price) - np.log(current_price)

    #     return 0

    # def update_portfolio_log_returns(self, action):
    #     self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
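
The abstract methods above (set_action_space, step, is_tradesignal, calculate_reward) and the _is_valid hook are meant to be filled in by the action-specific child environments. Below is a minimal sketch of what such a subclass could look like, assuming the 5-action layout from BaseActions; the class name MyRLEnv, the reward values and the simplified position bookkeeping are illustrative only and do not reproduce the implementation shipped in Base5ActionRLEnv.py.

class MyRLEnv(BaseEnvironment):
    """
    Illustrative subclass sketch; reward values and bookkeeping are placeholders.
    """

    def set_action_space(self):
        # One discrete action per BaseActions member.
        self.action_space = spaces.Discrete(len(BaseActions))

    def is_tradesignal(self, action: int) -> bool:
        # Anything other than staying Neutral while flat counts as a trade signal.
        return not (action == BaseActions.Neutral.value
                    and self._position == Positions.Neutral)

    def _is_valid(self, action: int) -> bool:
        # Exits require the matching position; entries require being flat.
        if action == BaseActions.Long_exit.value:
            return self._position == Positions.Long
        if action == BaseActions.Short_exit.value:
            return self._position == Positions.Short
        if action in (BaseActions.Long_enter.value, BaseActions.Short_enter.value):
            return self._position == Positions.Neutral
        return True

    def calculate_reward(self, action: int) -> float:
        # Placeholder reward: penalize invalid actions, pay the unrealized
        # profit on exits, otherwise zero.
        if not self._is_valid(action):
            self.tensorboard_log("invalid")
            return -2.
        if action in (BaseActions.Long_exit.value, BaseActions.Short_exit.value):
            return float(self.get_unrealized_profit())
        return 0.

    def step(self, action: int):
        # Advance one candle, score the action, apply the position change and
        # return the usual gymnasium 5-tuple.
        self._current_tick += 1
        reward = self.calculate_reward(action)
        self.total_reward += reward

        if self._is_valid(action) and self.is_tradesignal(action):
            if action == BaseActions.Long_enter.value:
                self._position = Positions.Long
                self._last_trade_tick = self._current_tick
            elif action == BaseActions.Short_enter.value:
                self._position = Positions.Short
                self._last_trade_tick = self._current_tick
            elif action in (BaseActions.Long_exit.value, BaseActions.Short_exit.value):
                self._update_total_profit()
                self._position = Positions.Neutral
                self._last_trade_tick = None
        self._position_history.append(self._position)

        self._done = self._current_tick == self._end_tick
        info = dict(tick=self._current_tick, total_reward=self.total_reward)
        self._update_history(info)
        return self._get_observation(), reward, self._done, False, info

Driving the sketch follows the standard gymnasium loop: obs, history = env.reset(), then obs, reward, done, truncated, info = env.step(action) until done. In freqtrade itself the concrete environments are constructed and trained by the FreqAI reinforcement-learning code rather than driven by hand.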