• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WenjieDu / PyPOTS / 4845620836

pending completion
4845620836

Pull #77

github

GitHub
Merge fd7b05c7e into 39b2bbebd
Pull Request #77: Fix dependency error in daily testing

3161 of 3722 relevant lines covered (84.93%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.8
/pypots/data/base.py
1
"""
1✔
2
The base class for PyPOTS datasets.
3
"""
4

5
# Created by Wenjie Du <wenjay.du@gmail.com>
6
# License: GPL-v3
7

8
from abc import abstractmethod
1✔
9
from typing import Union, Optional, Tuple, Iterable
1✔
10

11
import h5py
1✔
12
import numpy as np
1✔
13
import torch
1✔
14
from torch.utils.data import Dataset
1✔
15

16
# Currently we only support h5 files
17
# File types accepted when a dataset is given as a path string; only HDF5 files opened through h5py for now.
SUPPORTED_DATASET_FILE_TYPE = ["h5py"]
1✔
18

19

20
class BaseDataset(Dataset):
1✔
21
    """Base dataset class in PyPOTS.
1✔
22

23
    data : dict or str,
24
        The dataset for model input, should be a dictionary including keys as 'X' and 'y',
25
        or a path string locating a data file.
26
        If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
27
        which is time-series data for input, can contain missing values, and y should be array-like of shape
28
        [n_samples], which is classification labels of X.
29
        If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
30
        key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
31

32
    return_labels : bool, default = True,
33
        Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
34
        during training of classification models, the Dataset class will return labels in __getitem__() for model input.
35
        Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
36
        need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
37
        files, they already have both X and y saved. But we don't read labels from the file for validating and testing
38
        with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
39
        distinction.
40

41
    file_type : str, default = "h5py"
42
        The type of the given file if train_set and val_set are path strings.
43
    """
44

45
    def __init__(
1✔
46
        self,
47
        data: Union[dict, str],
48
        return_labels: bool = True,
49
        file_type: str = "h5py",
50
    ):
51
        super().__init__()
1✔
52
        # types and shapes had been checked after X and y input into the model
53
        # So they are safe to use here. No need to check again.
54

55
        self.data = data
1✔
56
        self.return_labels = return_labels
1✔
57
        if isinstance(self.data, str):  # data from file
1✔
58
            # check if the given file type is supported
59
            assert (
1✔
60
                file_type in SUPPORTED_DATASET_FILE_TYPE
61
            ), f"file_type should be one of {SUPPORTED_DATASET_FILE_TYPE}, but got {file_type}"
62

63
            self.file_type = file_type
1✔
64

65
            # open the file handle
66
            self.file_handle = self._open_file_handle()
1✔
67
            # check if X exists in the file
68
            assert (
1✔
69
                "X" in self.file_handle.keys()
70
            ), "The given dataset file doesn't contains X. Please double check."
71

72
        else:  # data from array
73
            X = data["X"]
1✔
74
            y = None if "y" not in data.keys() else data["y"]
1✔
75
            self.X, self.y = self.check_input(X, y)
1✔
76

77
        self.sample_num = self._get_sample_num()
1✔
78

79
        # set up function fetch_data()
80
        if isinstance(self.data, str):
1✔
81
            self.fetch_data = self._fetch_data_from_file
1✔
82
        else:
83
            self.fetch_data = self._fetch_data_from_array
1✔
84

85
    def _get_sample_num(self) -> int:
1✔
86
        """Determine the number of samples in the dataset and return the number.
87

88
        Returns
89
        -------
90
        sample_num : int
91
            The number of the samples in the given dataset.
92
        """
93
        if isinstance(self.data, str):
1✔
94
            if self.file_handle is None:
1✔
95
                self.file_handle = self._open_file_handle()
×
96
            sample_num = len(self.file_handle["X"])
1✔
97
        else:
98
            sample_num = len(self.X)
1✔
99

100
        return sample_num
1✔
101

102
    def __len__(self) -> int:
1✔
103
        return self.sample_num
1✔
104

105
    @staticmethod
1✔
106
    def check_input(
1✔
107
        X: Union[np.ndarray, torch.Tensor, list],
108
        y: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
109
        out_dtype: str = "tensor",
110
    ) -> Tuple[
111
        Union[np.ndarray, torch.Tensor, list],
112
        Optional[Union[np.ndarray, torch.Tensor, list]],
113
    ]:
114
        """Check value type and shape of input X and y
115

116
        Parameters
117
        ----------
118
        X : array-like,
119
            Time-series data that must have a shape like [n_samples, expected_n_steps, expected_n_features].
120

121
        y : array-like, default=None
122
            Labels of time-series samples (X) that must have a shape like [n_samples] or [n_samples, n_classes].
123

124
        out_dtype : str, in ['tensor', 'ndarray'], default='tensor'
125
            Data type of the output, should be np.ndarray or torch.Tensor
126

127
        Returns
128
        -------
129
        X : array-like
130

131
        y : array-like
132
        """
133
        assert out_dtype in [
1✔
134
            "tensor",
135
            "ndarray",
136
        ], f'out_dtype should be "tensor" or "ndarray", but got {out_dtype}'
137

138
        is_list = isinstance(X, list)
1✔
139
        is_array = isinstance(X, np.ndarray)
1✔
140
        is_tensor = isinstance(X, torch.Tensor)
1✔
141
        assert is_tensor or is_array or is_list, TypeError(
1✔
142
            "X should be an instance of list/np.ndarray/torch.Tensor, "
143
            f"but got {type(X)}"
144
        )
145

146
        # convert the data type if in need
147
        if out_dtype == "tensor":
1✔
148
            if is_list:
1✔
149
                X = torch.tensor(X)
×
150
            elif is_array:
1✔
151
                X = torch.from_numpy(X)
1✔
152
            else:  # is tensor
153
                pass
×
154
        else:  # out_dtype is ndarray
155
            # convert to np.ndarray first for shape check
156
            if is_list:
×
157
                X = np.asarray(X)
×
158
            elif is_tensor:
×
159
                X = X.numpy()
×
160
            else:  # is ndarray
161
                pass
×
162

163
        # check the shape of X here
164
        X_shape = X.shape
1✔
165
        assert len(X_shape) == 3, (
1✔
166
            f"input should have 3 dimensions [n_samples, seq_len, n_features],"
167
            f"but got shape={X_shape}"
168
        )
169

170
        if y is not None:
1✔
171
            assert len(X) == len(y), (
1✔
172
                f"lengths of X and y must match, " f"but got f{len(X)} and {len(y)}"
173
            )
174
            if isinstance(y, torch.Tensor):
1✔
175
                y = y if out_dtype == "tensor" else y.numpy()
×
176
            elif isinstance(y, list):
1✔
177
                y = torch.tensor(y) if out_dtype == "tensor" else np.asarray(y)
×
178
            elif isinstance(y, np.ndarray):
1✔
179
                y = torch.from_numpy(y) if out_dtype == "tensor" else y
1✔
180
            else:
181
                raise TypeError(
×
182
                    "y should be an instance of list/np.ndarray/torch.Tensor, "
183
                    f"but got {type(y)}"
184
                )
185

186
        return X, y
1✔
187

188
    @abstractmethod
1✔
189
    def _fetch_data_from_array(self, idx: int) -> Iterable:
1✔
190
        """Fetch data from self.X if it is given.
191

192
        Parameters
193
        ----------
194
        idx : int,
195
            The index of the sample to be return.
196

197
        Returns
198
        -------
199
        sample : list,
200
            The collated data sample, a list including all necessary sample info.
201
        """
202

203
        X = self.X[idx]
1✔
204
        missing_mask = ~torch.isnan(X)
1✔
205
        X = torch.nan_to_num(X)
1✔
206
        sample = [
1✔
207
            torch.tensor(idx),
208
            X.to(torch.float32),
209
            missing_mask.to(torch.float32),
210
        ]
211

212
        if self.y is not None and self.return_labels:
1✔
213
            sample.append(self.y[idx].to(torch.long))
×
214

215
        return sample
1✔
216

217
    def _open_file_handle(self) -> h5py.File:
1✔
218
        """Open the file handle for reading data from the file.
219

220
        Notes
221
        -----
222
        This function can also help confirm if the given file and file type match.
223

224
        Returns
225
        -------
226
        file_handle : file
227

228
        """
229
        data_file_path = self.data
1✔
230
        try:
1✔
231
            file_handler = h5py.File(
1✔
232
                data_file_path,
233
                "r",
234
            )  # set swmr=True if the h5 file need to be written into new content during reading
235
        except ImportError:
×
236
            raise ImportError(
×
237
                "h5py is missing and cannot be imported. Please install it first."
238
            )
239
        except OSError as e:
×
240
            raise TypeError(
×
241
                f"{e} This probably is caused by file type error. "
242
                f"Please confirm that the given file {data_file_path} is an h5 file."
243
            )
244
        except Exception as e:
×
245
            raise RuntimeError(e)
×
246
        return file_handler
1✔
247

248
    @abstractmethod
1✔
249
    def _fetch_data_from_file(self, idx: int) -> Iterable:
1✔
250
        """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples.
251
        Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice.
252

253
        Notes
254
        -----
255
        Multi workers reading from h5 file is tricky, and I was confronted with a problem similar to
256
        https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/7 in 2020, please
257
        refer to it for more details about the problem.
258
        The implementation here is referred to
259
        https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/10
260
        And according to https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/37,
261
        pytorch v1.7.1 and h5py v3.2.0 work well, so probably updating to the latest version can avoid the
262
        issue I met. After all, this implementation may need to be updated in the near future.
263

264
        Parameters
265
        ----------
266
        idx : int,
267
            The index of the sample to be return.
268

269
        Returns
270
        -------
271
        sample : list,
272
            The collated data sample, a list including all necessary sample info.
273
        """
274

275
        if self.file_handle is None:
1✔
276
            self.file_handle = self._open_file_handle()
×
277

278
        X = torch.from_numpy(self.file_handle["X"][idx])
1✔
279
        missing_mask = ~torch.isnan(X)
1✔
280
        X = torch.nan_to_num(X)
1✔
281
        sample = [
1✔
282
            torch.tensor(idx),
283
            X.to(torch.float32),
284
            missing_mask.to(torch.float32),
285
        ]
286

287
        # if the dataset has labels and is for training, then fetch it from the file
288
        if "y" in self.file_handle.keys() and self.return_labels:
1✔
289
            sample.append(self.file_handle["y"][idx].to(torch.long))
×
290

291
        return sample
1✔
292

293
    def __getitem__(self, idx: int) -> Iterable:
1✔
294
        """Fetch data according to index.
295

296
        Parameters
297
        ----------
298
        idx : int,
299
            The index to fetch the specified sample.
300

301
        Returns
302
        -------
303
        sample : list,
304
            The collated data sample, a list including all necessary sample info.
305
        """
306

307
        sample = self.fetch_data(idx)
1✔
308
        return sample
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc