Source File: /freqtrade/freqai/torch/PyTorchModelTrainer.py

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset

from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface

from .datasets import WindowDataset


logger = logging.getLogger(__name__)


class PyTorchModelTrainer(PyTorchTrainerInterface):
    def __init__(
            self,
            model: nn.Module,
            optimizer: Optimizer,
            criterion: nn.Module,
            device: str,
            data_convertor: PyTorchDataConvertor,
            model_meta_data: Optional[Dict[str, Any]] = None,
            window_size: int = 1,
            tb_logger: Any = None,
            **kwargs
    ):
        """
        :param model: The PyTorch model to be trained.
        :param optimizer: The optimizer to use for training.
        :param criterion: The loss function to use for training.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :param data_convertor: Converter from pd.DataFrame to torch.tensor.
        :param model_meta_data: Additional metadata about the model (optional).
        :param n_steps: Used to calculate n_epochs: the number of training
            iterations to run, where one iteration is one optimizer.step() call.
            Ignored if n_epochs is set.
        :param n_epochs: The maximum number of epochs to run during training.
            Takes precedence over n_steps.
        :param batch_size: The size of the batches to use during training.

        Note: a dictionary containing the initial model/optimizer state_dict and
        model_meta_data, as saved by the self.save() method, can be restored via
        load() or load_from_checkpoint().
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.model_meta_data = model_meta_data if model_meta_data is not None else {}
        self.device = device
        self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
        self.n_steps: Optional[int] = kwargs.get("n_steps", None)
        if self.n_steps is None and not self.n_epochs:
            raise ValueError("Either `n_steps` or `n_epochs` should be set.")

        self.batch_size: int = kwargs.get("batch_size", 64)
        self.data_convertor = data_convertor
        self.window_size: int = window_size
        self.tb_logger = tb_logger
        self.test_batch_counter = 0

    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
        """
        :param data_dictionary: The dictionary constructed by DataHandler to hold
            all the training and test data/labels.
        :param splits: Splits to use in training; must contain "train". An optional
            "test" split can be added by setting
            freqai.data_split_parameters.test_size > 0 in the config file.

        For each batch, this method:
         - Calculates the predicted output for the batch using the PyTorch model.
         - Calculates the loss between the predicted and actual output using a loss function.
         - Computes the gradients of the loss with respect to the model's parameters using
           backpropagation.
         - Updates the model's parameters using an optimizer.
        """
        self.model.train()

        data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
        n_obs = len(data_dictionary["train_features"])
        n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs)
        batch_counter = 0
        for _ in range(n_epochs):
            for batch_data in data_loaders_dictionary["train"]:
                xb, yb = batch_data
                xb = xb.to(self.device)
                yb = yb.to(self.device)
                yb_pred = self.model(xb)
                loss = self.criterion(yb_pred, yb)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()
                self.tb_logger.log_scalar("train_loss", loss.item(), batch_counter)
                batch_counter += 1

            # evaluation
            if "test" in splits:
                self.estimate_loss(data_loaders_dictionary, "test")

    @torch.no_grad()
    def estimate_loss(
            self,
            data_loader_dictionary: Dict[str, DataLoader],
            split: str,
    ) -> None:
        self.model.eval()
        for batch_data in data_loader_dictionary[split]:
            xb, yb = batch_data
            xb = xb.to(self.device)
            yb = yb.to(self.device)

            yb_pred = self.model(xb)
            loss = self.criterion(yb_pred, yb)
            self.tb_logger.log_scalar(f"{split}_loss", loss.item(), self.test_batch_counter)
            self.test_batch_counter += 1

        self.model.train()

    def create_data_loaders_dictionary(
            self,
            data_dictionary: Dict[str, pd.DataFrame],
            splits: List[str]
    ) -> Dict[str, DataLoader]:
        """
        Converts the input data to PyTorch tensors using a data loader.
        """
        data_loader_dictionary = {}
        for split in splits:
            x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
            y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
            dataset = TensorDataset(x, y)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=0,
            )
            data_loader_dictionary[split] = data_loader

        return data_loader_dictionary

    def calc_n_epochs(self, n_obs: int) -> int:
        """
        Calculates the number of epochs required to reach the maximum number
        of iterations specified in the model training parameters.

        The motivation here is that `n_steps` is easier to optimize and keep
        stable across different values of n_obs (the number of data points).
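
        For example (a purely arithmetic illustration): with n_obs=6400 and
        batch_size=64 there are 100 batches per epoch, so n_steps=5000 yields
        n_epochs = max(5000 // 100, 1) = 50.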
        """
        assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
        # guard against n_obs < batch_size, which would otherwise make n_batches zero
        n_batches = max(n_obs // self.batch_size, 1)
        n_epochs = max(self.n_steps // n_batches, 1)
        if n_epochs <= 10:
            logger.warning(
                f"Setting low n_epochs: {n_epochs}. "
                f"Please consider increasing `n_steps` hyper-parameter."
            )

        return n_epochs

    def save(self, path: Path):
        """
        - Saves any nn.Module state_dict.
        - Saves model_meta_data; this dict should contain any additional data that
          the user needs to store, e.g. class_names for classification models.
        """

        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "model_meta_data": self.model_meta_data,
            "pytrainer": self
        }, path)

    def load(self, path: Path):
        checkpoint = torch.load(path)
        return self.load_from_checkpoint(checkpoint)

    def load_from_checkpoint(self, checkpoint: Dict):
        """
        When using continual_learning, DataDrawer will load the dictionary
        (containing state dicts and model_meta_data) by calling torch.load(path).
        You can access this dict from any class that inherits IFreqaiModel by
        calling the get_init_model method.
        """
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model_meta_data = checkpoint["model_meta_data"]
        return self


class PyTorchTransformerTrainer(PyTorchModelTrainer):
    """
    Trainer for the Transformer model.
    """

    def create_data_loaders_dictionary(
            self,
            data_dictionary: Dict[str, pd.DataFrame],
            splits: List[str]
    ) -> Dict[str, DataLoader]:
        """
        Converts the input data to PyTorch tensors using a data loader.
        """
        data_loader_dictionary = {}
        for split in splits:
            x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
            y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
            dataset = WindowDataset(x, y, self.window_size)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                # unlike the base trainer, keep the temporal order of the windows
                shuffle=False,
                drop_last=True,
                num_workers=0,
            )
            data_loader_dictionary[split] = data_loader

        return data_loader_dictionary
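

Usage sketch (not part of the source file): a minimal, hypothetical smoke test
of PyTorchModelTrainer. The stub convertor and logger below only mirror the
interfaces this file actually calls (convert_x/convert_y and log_scalar); they
are illustrative assumptions, not freqtrade classes — inside freqtrade these
objects are supplied by the FreqAI model classes.

    # Hypothetical stand-alone usage; all names below are illustrative.
    import numpy as np
    import pandas as pd
    import torch
    from torch import nn

    class _StubConvertor:
        # implements the convert_x/convert_y calls made by the trainer
        def convert_x(self, df, device):
            return torch.tensor(df.values, dtype=torch.float32, device=device)

        def convert_y(self, df, device):
            return torch.tensor(df.values, dtype=torch.float32, device=device)

    class _StubLogger:
        # implements the log_scalar call made by the trainer
        def log_scalar(self, tag, value, step):
            pass

    # fit() reads "<split>_features" / "<split>_labels" DataFrames
    data_dictionary = {
        "train_features": pd.DataFrame(np.random.randn(640, 10)),
        "train_labels": pd.DataFrame(np.random.randn(640, 1)),
    }

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    trainer = PyTorchModelTrainer(
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=3e-4),
        criterion=nn.MSELoss(),
        device="cpu",
        data_convertor=_StubConvertor(),
        tb_logger=_StubLogger(),
        n_epochs=2,
        batch_size=64,
    )
    trainer.fit(data_dictionary, splits=["train"])

Passing n_epochs directly sidesteps calc_n_epochs; supplying n_epochs=None with
an explicit n_steps would exercise the step-based path instead.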