
ContinualAI / avalanche · build 5399886876 (pending completion)

Pull Request #1398: switch to black formatting
Merge 2c8aba8e6 into a61ae5cab (github / web-flow)

1023 of 1372 new or added lines in 177 files covered (74.56%).
144 existing lines in 66 files now uncovered.
16366 of 22540 relevant lines covered (72.61%).
2.9 hits per line.

Source File: /avalanche/benchmarks/datasets/downloadable_dataset.py (31.25% covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 25-05-2021                                                             #
# Author: Lorenzo Pellegrini                                                   #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

from abc import abstractmethod, ABC
from pathlib import Path
from typing import TypeVar, Union, Optional

import shutil

import os
from torch.utils.data.dataset import Dataset
from torchvision.datasets.utils import (
    download_and_extract_archive,
    extract_archive,
    download_url,
    check_integrity,
)

from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location

T_co = TypeVar("T_co", covariant=True)


class DownloadableDataset(Dataset[T_co], ABC):
    """Base class for a downloadable dataset.

    It is recommended to extend this class if a dataset can be downloaded from
    the internet. This implementation codes the recommended behaviour for
    downloading and verifying the dataset.

    The dataset child class must implement the `_download_dataset`,
    `_load_metadata` and `_download_error_message` methods.

    The child class, in its constructor, must call the already implemented
    `_load_dataset` method (otherwise nothing will happen).

    A further simplification can be obtained by extending
    :class:`SimpleDownloadableDataset` instead of this class.
    :class:`SimpleDownloadableDataset` is recommended if a single archive is to
    be downloaded and extracted to the root folder "as is".

    The standardized procedure coded by `_load_dataset` is as follows:

    - First, `_load_metadata` is called to check if the dataset can be correctly
      loaded at the `root` path. This method must check if the data found
      at the `root` path is correct and that metadata can be correctly loaded.
      If this method succeeds (by returning True) the process is completed.
    - If `_load_metadata` fails (by returning False or by raising an error),
      then a download will be attempted if the download flag was set to True.
      The download must be implemented in `_download_dataset`. The
      procedure can be drastically simplified by using the `_download_file`,
      `_extract_archive` and `_download_and_extract_archive` helpers.
    - If the download succeeds (doesn't raise an error), then `_load_metadata`
      will be called again.

    If an error occurs, `_download_error_message` will be called to obtain
    a message (a string) to show to the user. That message should contain
    instructions on how to download and prepare the dataset manually.
    """

    def __init__(
        self,
        root: Union[str, Path],
        download: bool = True,
        verbose: bool = False,
    ):
        """Creates an instance of a downloadable dataset.

        Consider looking at the class documentation for the precise details on
        how to extend this class.

        Beware that calling this constructor only fills the `root` field. The
        download and metadata loading procedures are triggered only by
        calling `_load_dataset`.

        :param root: The root path where the dataset will be downloaded.
            Consider passing a path obtained by calling
            `default_dataset_location` with the name of the dataset.
        :param download: If True, the dataset will be downloaded if needed.
            If False and the dataset can't be loaded from the provided root
            path, an error will be raised when calling the `_load_dataset`
            method. Defaults to True.
        :param verbose: If True, some info about the download process will be
            printed. Defaults to False.
        """

        super(DownloadableDataset, self).__init__()
        self.root: Path = Path(root)
        """
        The path to the dataset.
        """

        self.download: bool = download
        """
        If True, the dataset will be downloaded (only if needed).
        """

        self.verbose: bool = verbose
        """
        If True, some info about the download process will be printed.
        """

    def _load_dataset(self) -> None:
        """
        The standardized dataset download and load procedure.

        For more details on the coded procedure see the class documentation.

        This method shouldn't be overridden.

        This method will raise an error if the dataset couldn't be loaded
        or downloaded.

        :return: None
        """
        metadata_loaded = False
        metadata_load_error = None
        try:
            metadata_loaded = self._load_metadata()
        except Exception as e:
            metadata_load_error = e

        if metadata_loaded:
            if self.verbose:
                print("Files already downloaded and verified")
            return

        if not self.download:
            msg = (
                "Error loading dataset metadata (dataset download was "
                'not attempted as "download" is set to False)'
            )
            if metadata_load_error is None:
                raise RuntimeError(msg)
            else:
                print(msg)
                raise metadata_load_error

        try:
            self._download_dataset()
        except Exception as e:
            err_msg = self._download_error_message()
            print(err_msg, flush=True)
            raise e

        if not self._load_metadata():
            err_msg = self._download_error_message()
            print(err_msg)
            raise RuntimeError(
                "Error loading dataset metadata (... but the download "
                "procedure completed successfully)"
            )

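The control flow of `_load_dataset` can be restated as a standalone sketch. This is illustrative only, not Avalanche code: the callables stand in for the `_load_metadata`, `_download_dataset` and `_download_error_message` methods, and the function name `load_or_download` is hypothetical.

```python
from typing import Callable, Optional


def load_or_download(
    load_metadata: Callable[[], bool],
    download_dataset: Callable[[], None],
    error_message: Callable[[], str],
    download: bool = True,
) -> None:
    # Step 1: try to load the metadata from the root path.
    load_error: Optional[Exception] = None
    try:
        if load_metadata():
            return  # Dataset already present and verified.
    except Exception as e:
        load_error = e

    # Step 2: if downloading is disabled, fail immediately, preferring the
    # original metadata-loading error when one was raised.
    if not download:
        if load_error is not None:
            raise load_error
        raise RuntimeError("Metadata load failed and download is disabled")

    # Step 3: attempt the download; print the manual-download hint on error.
    try:
        download_dataset()
    except Exception:
        print(error_message(), flush=True)
        raise

    # Step 4: the metadata must now load; otherwise something went wrong
    # even though the download itself completed.
    if not load_metadata():
        raise RuntimeError("Download completed but metadata still fails to load")
```

Note how the download is attempted at most once: a second `_load_metadata` failure is treated as a hard error rather than triggering another download.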
    @abstractmethod
    def _download_dataset(self) -> None:
        """
        The download procedure.

        This procedure is called only if `_load_metadata` fails.

        This method must raise an error if the dataset can't be downloaded.

        Hints: don't re-invent the wheel! There are ready-to-use helper methods
        like `_download_and_extract_archive`, `_download_file` and
        `_extract_archive` that can be used.

        :return: None
        """
        pass

    @abstractmethod
    def _load_metadata(self) -> bool:
        """
        The dataset metadata loading procedure.

        This procedure is called at least once to load the dataset metadata.

        This procedure should return False if the dataset is corrupted or if it
        can't be loaded.

        :return: True if the dataset is not corrupted and could be successfully
            loaded.
        """
        pass

    @abstractmethod
    def _download_error_message(self) -> str:
        """
        Returns the error message hinting the user on how to download the
        dataset manually.

        :return: A string representing the message to show to the user.
        """
        pass

    def _cleanup_dataset_root(self):
        """
        Utility method that can be used to remove the dataset root directory.

        Can be useful if a cleanup is needed when downloading and extracting the
        dataset.

        This method will also re-create the root directory.

        :return: None
        """
        shutil.rmtree(self.root)
        self.root.mkdir(parents=True, exist_ok=True)

    def _download_file(self, url: str, file_name: str, checksum: Optional[str]) -> Path:
        """
        Utility method that can be used to download and verify a file.

        :param url: The download url.
        :param file_name: The name of the file to save. The file will be saved
            in the `root` with this name. Always fill this parameter.
            Don't pass a path! Pass a file name only!
        :param checksum: The MD5 hash to use when verifying the downloaded
            file. Can be None, in which case the check will be skipped.
            It is recommended to always fill this parameter.
        :return: The path to the downloaded file.
        """
        self.root.mkdir(parents=True, exist_ok=True)
        download_url(url, str(self.root), filename=file_name, md5=checksum)
        return self.root / file_name

    def _extract_archive(
        self,
        path: Union[str, Path],
        sub_directory: Optional[str] = None,
        remove_archive: bool = False,
    ) -> Path:
        """
        Utility method that can be used to extract an archive.

        :param path: The complete path to the archive (for instance obtained
            by calling `_download_file`).
        :param sub_directory: The name of the sub directory where to extract the
            archive. Can be None, which means that the archive will be extracted
            in the root. Beware that some archives already have a root directory
            inside of them, in which case it's probably better to use None here.
            Defaults to None.
        :param remove_archive: If True, the archive will be deleted after a
            successful extraction. Defaults to False.
        :return: The path to the directory where the archive was extracted.
        """

        if sub_directory is None:
            extract_root = self.root
        else:
            extract_root = self.root / sub_directory

        extract_archive(
            str(path), to_path=str(extract_root), remove_finished=remove_archive
        )

        return extract_root

    def _download_and_extract_archive(
        self,
        url: str,
        file_name: str,
        checksum: Optional[str],
        sub_directory: Optional[str] = None,
        remove_archive: bool = False,
    ) -> Path:
        """
        Utility that downloads and extracts an archive.

        :param url: The download url.
        :param file_name: The name of the archive. The file will be saved
            in the `root` with this name. Always fill this parameter.
            Don't pass a path! Pass a file name only!
        :param checksum: The MD5 hash to use when verifying the downloaded
            archive. Can be None, in which case the check will be skipped.
            It is recommended to always fill this parameter.
        :param sub_directory: The name of the sub directory where to extract the
            archive. Can be None, which means that the archive will be extracted
            in the root. Beware that some archives already have a root directory
            inside of them, in which case it's probably better to use None here.
            Defaults to None.
        :param remove_archive: If True, the archive will be deleted after a
            successful extraction. Defaults to False.
        :return: The path to the extracted archive. If `sub_directory` is None,
            then this will be the `root` path.
        """
        if sub_directory is None:
            extract_root = self.root
        else:
            extract_root = self.root / sub_directory

        self.root.mkdir(parents=True, exist_ok=True)
        try:
            download_and_extract_archive(
                url,
                str(self.root),
                extract_root=str(extract_root),
                filename=file_name,
                md5=checksum,
                remove_finished=remove_archive,
            )
        except BaseException:
            print(
                "Error while downloading the dataset archive. "
                "The partially downloaded archive will be removed."
            )
            attempt_fpath = self.root / file_name
            attempt_fpath.unlink(missing_ok=True)
            raise

        return extract_root

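The `except BaseException` cleanup above follows a reusable pattern: if anything goes wrong mid-download (including `KeyboardInterrupt`, which is why `BaseException` rather than `Exception` is caught), remove the partial artifact and re-raise the original error. A minimal standalone sketch; the function name `download_with_cleanup` is hypothetical:

```python
from pathlib import Path
from typing import Callable


def download_with_cleanup(do_download: Callable[[], None], partial_path: Path) -> None:
    # Run the download; on any failure, delete the partially written file
    # (if it exists) and re-raise the original exception unchanged.
    try:
        do_download()
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise
```

`Path.unlink(missing_ok=True)` (Python 3.8+) makes the cleanup safe even when the failure happened before any bytes were written.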
    def _check_file(self, path: Union[str, Path], checksum: str) -> bool:
        """
        Utility method to check a file.

        :param path: The path to the file.
        :param checksum: The MD5 hash to use.
        :return: True if the MD5 hash of the file matches the given one.
        """
        return check_integrity(str(path), md5=checksum)


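The integrity check delegated to torchvision's `check_integrity` boils down to comparing the file's MD5 digest against an expected hex string. A self-contained approximation using only the standard library (the chunked reading and the helper name `md5_matches` are illustrative choices, not torchvision's exact code):

```python
import hashlib
from pathlib import Path
from typing import Optional, Union


def md5_matches(path: Union[str, Path], checksum: Optional[str]) -> bool:
    # Mirrors the checksum=None convention used throughout this module:
    # no expected hash means the check is skipped and the file accepted.
    if checksum is None:
        return True
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large archives don't need to fit in memory.
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            md5.update(chunk)
    return md5.hexdigest() == checksum
```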
class SimpleDownloadableDataset(DownloadableDataset[T_co], ABC):
    """
    Base class for a downloadable dataset consisting of a single archive file.

    It is recommended to extend this class if a dataset can be downloaded from
    the internet as a single archive. For multi-file implementations, or if
    more fine-grained control is required, consider extending
    :class:`DownloadableDataset` instead.

    This is a simplified version of :class:`DownloadableDataset` where the
    following assumptions must hold:

    - The dataset is made of a single archive.
    - The archive must be extracted to the root folder "as is" (which means
      that no subdirectories must be created).

    The child class is only required to extend the `_load_metadata` method,
    which must check the dataset integrity and load the dataset metadata.

    Apart from that, the same assumptions of :class:`DownloadableDataset` hold.
    Remember to call the `_load_dataset` method in the child class constructor.
    """

    def __init__(
        self,
        root_or_dataset_name: Union[str, Path],
        url: str,
        checksum: Optional[str],
        download: bool = False,
        verbose: bool = False,
    ):
        """
        Creates an instance of a simple downloadable dataset.

        Consider looking at the class documentation for the precise details on
        how to extend this class.

        Beware that calling this constructor only fills the `root` field. The
        download and metadata loading procedures are triggered only by
        calling `_load_dataset`.

        :param root_or_dataset_name: The root path where the dataset will be
            downloaded. If a directory name is passed, then the root obtained by
            calling `default_dataset_location` will be used (recommended).
            To check if this parameter is a path, the constructor will check if
            it contains the '\\' or '/' characters or if it is a Path instance.
        :param url: The url of the archive.
        :param checksum: The MD5 hash to use when verifying the downloaded
            archive. Can be None, in which case the check will be skipped.
            It is recommended to always fill this parameter.
        :param download: If True, the dataset will be downloaded if needed.
            If False and the dataset can't be loaded from the provided root
            path, an error will be raised when calling the `_load_dataset`
            method. Defaults to False.
        :param verbose: If True, some info about the download process will be
            printed. Defaults to False.
        """

        self.url = url
        self.checksum = checksum

        if (
            isinstance(root_or_dataset_name, Path)
            or "/" in root_or_dataset_name
            or "\\" in root_or_dataset_name
        ):
            root = Path(root_or_dataset_name)
        else:
            root = default_dataset_location(root_or_dataset_name)

        super(SimpleDownloadableDataset, self).__init__(
            root, download=download, verbose=verbose
        )

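The path-versus-name check in this constructor can be isolated into a small helper for illustration. The names `resolve_root` and `default_location` are hypothetical; in Avalanche the fallback is `default_dataset_location`:

```python
from pathlib import Path
from typing import Callable, Union


def resolve_root(
    root_or_dataset_name: Union[str, Path],
    default_location: Callable[[str], Path],
) -> Path:
    # A Path instance, or a string containing a path separator ('/' or '\\'),
    # is taken verbatim; a bare name is resolved through the default location.
    if (
        isinstance(root_or_dataset_name, Path)
        or "/" in root_or_dataset_name
        or "\\" in root_or_dataset_name
    ):
        return Path(root_or_dataset_name)
    return default_location(root_or_dataset_name)
```

So `resolve_root("mnist", loc)` defers to `loc("mnist")`, while `resolve_root("/tmp/mnist", loc)` uses the given path directly.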
    def _download_dataset(self) -> None:
        filename = os.path.basename(self.url)
        self._download_and_extract_archive(
            self.url,
            filename,
            self.checksum,
            sub_directory=None,
            remove_archive=False,
        )

    def _download_error_message(self) -> str:
        return (
            "Error downloading the dataset. Consider downloading "
            "it manually at: " + self.url + " and placing it "
            "in: " + str(self.root)
        )


__all__ = ["DownloadableDataset", "SimpleDownloadableDataset"]

© 2025 Coveralls, Inc