
freqtrade / freqtrade / 9394559170

26 Apr 2024 06:36AM UTC coverage: 94.656% (-0.02%) from 94.674%

push (github) by xmatthias: Loader should be passed as kwarg for clarity

20280 of 21425 relevant lines covered (94.66%)

0.95 hits per line

Source File: /freqtrade/freqai/torch/PyTorchTrainerInterface.py (87.5% covered)
```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List

import pandas as pd
import torch
from torch import nn


class PyTorchTrainerInterface(ABC):

    @abstractmethod
    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]) -> None:
        """
        :param data_dictionary: the dictionary constructed by DataHandler to hold
        all the training and test data/labels.
        :param splits: splits to use in training. splits must contain "train";
        "test" can optionally be added by setting freqai.data_split_parameters.test_size > 0
        in the config file.

        For each batch, the training loop:
         - Calculates the predicted output for the batch using the PyTorch model.
         - Calculates the loss between the predicted and actual output using a loss function.
         - Computes the gradients of the loss with respect to the model's parameters using
           backpropagation.
         - Updates the model's parameters using an optimizer.
        """

    @abstractmethod
    def save(self, path: Path) -> None:
        """
        - Saves the state_dict of any nn.Module.
        - Saves model_meta_data, a dict that should contain any additional data the
          user needs to store, e.g. class_names for classification models.
        """

    def load(self, path: Path) -> nn.Module:
        """
        :param path: path to the zip file.
        :returns: the PyTorch model.
        """
        # The two lines below were not covered by tests in this build.
        checkpoint = torch.load(path)
        return self.load_from_checkpoint(checkpoint)

    @abstractmethod
    def load_from_checkpoint(self, checkpoint: Dict) -> nn.Module:
        """
        When using continual_learning, DataDrawer will load the dictionary
        (containing state dicts and model_meta_data) by calling torch.load(path).
        You can access this dict from any class that inherits IFreqaiModel by calling
        the get_init_model method.
        :param checkpoint: dict containing the model & optimizer state dicts,
        model_meta_data, etc.
        """
```
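The interface leaves `fit`, `save` and `load_from_checkpoint` to subclasses. A minimal sketch of one possible subclass follows; it is not freqtrade's own trainer. Assumptions: `TrainerInterface` is a hypothetical local stand-in mirroring the abstract class above so the snippet runs without freqtrade installed; `SimpleRegressionTrainer`, its full-batch MSE/Adam loop and the checkpoint keys `model_state_dict` and `optimizer_state_dict` are illustrative choices; and `data_dictionary` is assumed to use `"<split>_features"` / `"<split>_labels"` keys.

```python
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List

import pandas as pd
import torch
from torch import nn


# Hypothetical local stand-in mirroring PyTorchTrainerInterface (no freqtrade import needed).
class TrainerInterface(ABC):
    @abstractmethod
    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]) -> None: ...

    @abstractmethod
    def save(self, path: Path) -> None: ...

    def load(self, path: Path) -> nn.Module:
        return self.load_from_checkpoint(torch.load(path))

    @abstractmethod
    def load_from_checkpoint(self, checkpoint: Dict) -> nn.Module: ...


class SimpleRegressionTrainer(TrainerInterface):
    """Hypothetical trainer: full-batch MSE regression with Adam (not freqtrade's trainer)."""

    def __init__(self, model: nn.Module, lr: float = 0.01, n_epochs: int = 100):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        self.n_epochs = n_epochs
        self.model_meta_data: Dict = {}
        self.split_losses: Dict[str, float] = {}

    @staticmethod
    def _tensor(data_dictionary: Dict[str, pd.DataFrame], split: str, kind: str) -> torch.Tensor:
        # Assumed key layout: "<split>_features" / "<split>_labels".
        return torch.tensor(data_dictionary[f"{split}_{kind}"].values, dtype=torch.float32)

    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]) -> None:
        x = self._tensor(data_dictionary, "train", "features")
        y = self._tensor(data_dictionary, "train", "labels")
        self.model.train()
        for _ in range(self.n_epochs):
            y_pred = self.model(x)            # predicted output for the (single) batch
            loss = self.criterion(y_pred, y)  # loss between predicted and actual output
            self.optimizer.zero_grad()
            loss.backward()                   # gradients via backpropagation
            self.optimizer.step()             # parameter update via the optimizer
        # Record the final loss on every requested split ("train", optionally "test").
        self.model.eval()
        with torch.no_grad():
            self.split_losses = {
                split: self.criterion(
                    self.model(self._tensor(data_dictionary, split, "features")),
                    self._tensor(data_dictionary, split, "labels"),
                ).item()
                for split in splits
            }

    def save(self, path: Path) -> None:
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "model_meta_data": self.model_meta_data,
            },
            path,
        )

    def load_from_checkpoint(self, checkpoint: Dict) -> nn.Module:
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model_meta_data = checkpoint["model_meta_data"]
        return self.model


if __name__ == "__main__":
    torch.manual_seed(0)
    features = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
    labels = pd.DataFrame({"target": [1.0, 3.0, 5.0, 7.0]})  # target = 2 * a + 1
    trainer = SimpleRegressionTrainer(nn.Linear(2, 1), lr=0.1, n_epochs=200)
    trainer.model_meta_data = {"feature_names": list(features.columns)}
    trainer.fit({"train_features": features, "train_labels": labels}, splits=["train"])
    print(trainer.split_losses)
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "model.zip"
        trainer.save(path)
        restored = SimpleRegressionTrainer(nn.Linear(2, 1)).load(path)
        print(torch.equal(restored.weight, trainer.model.weight))
```

Keeping the optimizer state next to the model weights mirrors the checkpoint contents described in the `load_from_checkpoint` docstring (model & optimizer state dicts plus model_meta_data), so a restored trainer can continue training rather than restart its optimizer from scratch.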

© 2025 Coveralls, Inc