• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5399886876

pending completion
5399886876

Pull #1398

github

web-flow
Merge 2c8aba8e6 into a61ae5cab
Pull Request #1398: switch to black formatting

1023 of 1372 new or added lines in 177 files covered. (74.56%)

144 existing lines in 66 files now uncovered.

16366 of 22540 relevant lines covered (72.61%)

2.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

25.0
/avalanche/benchmarks/datasets/cub200/cub200.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 12-04-2021                                                             #
7
# Author: Lorenzo Pellegrini, Vincenzo Lomonaco                                #
8
# E-mail: contact@continualai.org                                              #
9
# Website: continualai.org                                                     #
10
################################################################################
11

12
"""
4✔
13
CUB200 Pytorch Dataset: Caltech-UCSD Birds-200-2011 (CUB-200-2011) is an
14
extended version of the CUB-200 dataset, with roughly double the number of
15
images per class and new part location annotations. For detailed information
16
about the dataset, please check the official website:
17
http://www.vision.caltech.edu/visipedia/CUB-200-2011.html.
18
"""
19

20
import csv
4✔
21
from pathlib import Path
4✔
22
from typing import Dict, List, Optional, Tuple, Union
4✔
23

24
import gdown
4✔
25
import os
4✔
26
from collections import OrderedDict
4✔
27
from torchvision.datasets.folder import default_loader
4✔
28

29
from avalanche.benchmarks.datasets import (
4✔
30
    default_dataset_location,
31
    DownloadableDataset,
32
)
33
from avalanche.benchmarks.utils import PathsDataset
4✔
34

35

36
class CUB200(PathsDataset, DownloadableDataset):
    """Basic CUB200 PathsDataset to be used as a standard PyTorch Dataset.

    A classic continual learning benchmark built on top of this dataset
    can be found in 'benchmarks.classic', while for more custom benchmark
    design please use the 'benchmarks.generators'."""

    # Path of the extracted images, relative to the dataset root.
    images_folder = "CUB_200_2011/images"
    # Official Caltech download URL (tried first; may no longer be reachable).
    official_url = (
        "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/" "CUB_200_2011.tgz"
    )
    # Google Drive mirror, used as a fallback when the direct download fails.
    gdrive_url = (
        "https://drive.google.com/u/0/uc?id=" "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
    )
    filename = "CUB_200_2011.tgz"
    tgz_md5 = "97eceeb196236b17998738112f37df78"

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        *,
        train=True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        download=True
    ):
        """

        :param root: root dir where the dataset can be found or downloaded.
            Defaults to None, which means that the default location for
            'CUB_200_2011' will be used.
        :param train: train or test subset of the original dataset. Default
            to True.
        :param transform: eventual input data transformations to apply.
            Default to None.
        :param target_transform: eventual target data transformations to apply.
            Default to None.
        :param loader: method to load the data from disk. Default to
            torchvision default_loader.
        :param download: default set to True. If the data is already
            downloaded it will skip the download.
        """

        if root is None:
            root = default_dataset_location("CUB_200_2011")

        self.train: bool = train

        # Needed for disambiguating the type,
        # which is not the same in the base classes
        self.root: Path = Path(root)
        # (relative image path, class label[, bounding box]) tuples,
        # filled by _load_metadata (called through _load_dataset below).
        self._images: List[Tuple[str, int]]

        # DownloadableDataset must be initialized before loading, since
        # _load_dataset may trigger the download/extraction machinery.
        DownloadableDataset.__init__(self, root, download=download, verbose=True)
        self._load_dataset()

        PathsDataset.__init__(
            self,
            os.path.join(root, CUB200.images_folder),
            self._images,
            transform=transform,
            target_transform=target_transform,
            loader=loader,
        )

    def _download_dataset(self) -> None:
        """Download and extract the CUB-200-2011 archive.

        The official Caltech URL is tried first; if it succeeds there is
        nothing more to do. Only on failure does the method fall back to
        the Google Drive mirror.

        Fixes: the original fell through to the GDrive download even after
        a successful direct download (missing early return), and in the
        fallback path it called ``gdown.download`` immediately followed by
        ``gdown.cached_download`` on the same file, downloading the archive
        twice. ``cached_download`` alone both fetches the file (when
        missing) and verifies its md5 checksum.
        """
        try:
            self._download_and_extract_archive(
                CUB200.official_url, CUB200.filename, checksum=CUB200.tgz_md5
            )
            # Direct download succeeded: skip the GDrive fallback entirely.
            return
        except Exception:
            if self.verbose:
                print(
                    "[CUB200] Direct download may no longer be possible, "
                    "will try GDrive."
                )

        filepath = self.root / self.filename
        # cached_download fetches the archive if needed and verifies its md5.
        gdown.cached_download(self.gdrive_url, str(filepath), md5=self.tgz_md5)

        self._extract_archive(filepath)

    def _download_error_message(self) -> str:
        """Return the message shown to the user when the download fails."""
        return (
            "[CUB200] Error downloading the dataset. Consider downloading "
            "it manually at: " + CUB200.official_url + " and placing it "
            "in: " + str(self.root)
        )

    def _load_metadata(self) -> bool:
        """Main method to load the CUB200 metadata.

        Parses the train/test split, image paths, class labels and bounding
        boxes from the extracted archive, keeping only the instances
        belonging to the split selected by ``self.train``.

        :return: True if the metadata was loaded and every referenced image
            file exists on disk, False otherwise (which triggers a
            re-download in the DownloadableDataset machinery).
        """

        cub_dir = self.root / "CUB_200_2011"
        images_list: Dict[int, List] = OrderedDict()

        # Select the image ids belonging to the requested split.
        with open(str(cub_dir / "train_test_split.txt")) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=" ")
            for row in csv_reader:
                img_id = int(row[0])
                is_train_instance = int(row[1]) == 1
                if is_train_instance == self.train:
                    images_list[img_id] = []

        # Relative image paths.
        with open(str(cub_dir / "images.txt")) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=" ")
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in images_list:
                    images_list[img_id].append(row[1])

        # Class labels.
        with open(str(cub_dir / "image_class_labels.txt")) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=" ")
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in images_list:
                    # CUB starts counting classes from 1 ...
                    images_list[img_id].append(int(row[1]) - 1)

        # Bounding boxes (CUB stores x, y, width, height).
        with open(str(cub_dir / "bounding_boxes.txt")) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=" ")
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in images_list:
                    box_cub = [int(float(x)) for x in row[1:]]
                    box_avl = [box_cub[1], box_cub[0], box_cub[3], box_cub[2]]
                    # PathsDataset accepts (top, left, height, width)
                    images_list[img_id].append(box_avl)

        images_tuples = []
        for _, img_tuple in images_list.items():
            images_tuples.append(tuple(img_tuple))
        self._images = images_tuples  # type: ignore

        # Integrity check
        for row_check in self._images:
            filepath = self.root / CUB200.images_folder / row_check[0]
            if not filepath.is_file():
                if self.verbose:
                    print("[CUB200] Error checking integrity of:", filepath)
                return False

        return True
179

180

181
if __name__ == "__main__":
    # Simple smoke test that runs when this script is executed directly:
    # build each split and display one sample image from it.

    import matplotlib.pyplot as plt

    for split_name, use_train, sample_idx in (
        ("test", False, 14),
        ("train", True, 700),
    ):
        split_dataset = CUB200(train=use_train, download=True)
        print(split_name + " data len:", len(split_dataset))
        sample_img, _ = split_dataset[sample_idx]
        plt.imshow(sample_img)
        plt.show()
197

198

199
# Public API of this module.
__all__ = ["CUB200"]
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc