
ContinualAI / avalanche, build 9268694310

28 May 2024 11:45AM UTC. Coverage: 51.799% (-0.004% from 51.803%)

Pull Request #1647: Improved Efficiency of the DiskUsage Metric (github / web-flow)
Merge 2995a272c into 8f0e61f23

18 of 30 new or added lines in 2 files covered. (60.0%)

2 existing lines in 2 files now uncovered.

15091 of 29134 relevant lines covered (51.8%)

0.52 hits per line

Source File: /avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py (0.0% of relevant lines covered)
################################################################################
# Copyright (c) 2022 ContinualAI                                               #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 18-02-2022                                                             #
# Author: Lorenzo Pellegrini                                                   #
#                                                                              #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

""" LVIS PyTorch Object Detection Dataset """

from pathlib import Path
import dill
from typing import Optional, Union, List, Sequence, TypedDict

import torch
from PIL import Image
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor

from avalanche.benchmarks.datasets import (
    DownloadableDataset,
    default_dataset_location,
)
from avalanche.benchmarks.datasets.lvis_dataset.lvis_data import lvis_archives
from avalanche.checkpointing import constructor_based_serialization

try:
    from lvis import LVIS
except ImportError:
    raise ModuleNotFoundError(
        "LVIS not found, if you want to use detection "
        "please install avalanche with the detection "
        "dependencies: "
        "pip install avalanche-lib[detection]"
    )

class LvisDataset(DownloadableDataset):
    """LVIS PyTorch Object Detection Dataset"""

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        *,
        train=True,
        transform=None,
        loader=default_loader,
        download=True,
        lvis_api: Optional[LVIS] = None,
        img_ids: Optional[List[int]] = None,
    ):
        """
        Creates an instance of the LVIS dataset.

        :param root: The directory where the dataset can be found or downloaded.
            Defaults to None, which means that the default location for
            "lvis" will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformation to apply to (img, annotations)
            values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        :param lvis_api: An instance of the LVIS class (from the lvis-api) to
            use. Defaults to None, which means that annotations will be loaded
            from the annotation json found in the root directory.
        :param img_ids: A list representing a subset of images to use. Defaults
            to None, which means that the dataset will contain all images
            in the LVIS dataset.
        """

        if root is None:
            root = default_dataset_location("lvis")

        self.train = train  # training set or test set
        self.transform = transform
        self.loader = loader
        self.bbox_crop = True
        self.img_ids: List[int] = img_ids  # type: ignore

        self.targets: LVISDetectionTargets = None  # type: ignore
        self.lvis_api: LVIS = lvis_api  # type: ignore

        super(LvisDataset, self).__init__(root, download=download, verbose=True)

        self._load_dataset()

    def _download_dataset(self) -> None:
        data2download = lvis_archives

        for name, url, checksum in data2download:
            if self.verbose:
                print("Downloading " + name + "...")

            result_file = self._download_file(url, name, checksum)
            if self.verbose:
                print("Download completed. Extracting...")

            self._extract_archive(result_file)
            if self.verbose:
                print("Extraction completed!")

    def _load_metadata(self) -> bool:
        must_load_api = self.lvis_api is None
        must_load_img_ids = self.img_ids is None
        try:
            # Load metadata
            if must_load_api:
                if self.train:
                    ann_json_path = str(self.root / "lvis_v1_train.json")
                else:
                    ann_json_path = str(self.root / "lvis_v1_val.json")

                self.lvis_api = LVIS(ann_json_path)

            lvis_api = self.lvis_api
            if must_load_img_ids:
                self.img_ids = list(sorted(lvis_api.get_img_ids()))

            self.targets = LVISDetectionTargets(lvis_api, self.img_ids)

            # Try loading an image
            if len(self.img_ids) > 0:
                img_id = self.img_ids[0]
                img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
                assert self._load_img(img_dict) is not None
        except BaseException:
            if must_load_api:
                self.lvis_api = None  # type: ignore
            if must_load_img_ids:
                self.img_ids = None  # type: ignore

            self.targets = None  # type: ignore
            raise

        return True

    def _download_error_message(self) -> str:
        return (
            "[LVIS] Error downloading the dataset. Consider "
            "downloading it manually at: https://www.lvisdataset.org/dataset"
            " and placing it in: " + str(self.root)
        )

    def __getitem__(self, index):
        """
        Loads an instance given its index.

        :param index: The index of the instance to retrieve.

        :return: a (sample, target) tuple where the target is a
            torchvision-style annotation for object detection
            https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
        """
        img_id = self.img_ids[index]
        img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
        annotation_dicts: LVISImgTargets = self.targets[index]

        # Transform from LVIS dictionary to torchvision-style target
        num_objs = annotation_dicts["bbox"].shape[0]

        boxes = []
        labels = []
        for i in range(num_objs):
            xmin = annotation_dicts["bbox"][i][0]
            ymin = annotation_dicts["bbox"][i][1]
            xmax = xmin + annotation_dicts["bbox"][i][2]
            ymax = ymin + annotation_dicts["bbox"][i][3]
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(annotation_dicts["category_id"][i])

        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
        else:
            boxes = torch.empty((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        image_id = torch.tensor([img_id])
        areas = []
        for i in range(num_objs):
            areas.append(annotation_dicts["area"][i])
        areas = torch.as_tensor(areas, dtype=torch.float32)
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = dict()
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = areas
        target["iscrowd"] = iscrowd

        img = self._load_img(img_dict)

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self):
        return len(self.img_ids)

    def _load_img(self, img_dict: "LVISImgEntry"):
        coco_url = img_dict["coco_url"]
        splitted_url = coco_url.split("/")
        img_path = splitted_url[-2] + "/" + splitted_url[-1]
        final_path = self.root / img_path  # <root>/train2017/<img_id>.jpg
        return self.loader(str(final_path))


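# NOTE (illustrative sketch, not part of the upstream file): the (img, target)
# pairs produced by LvisDataset.__getitem__ follow the torchvision
# object-detection convention ("boxes", "labels", "image_id", "area",
# "iscrowd"), so a batch can be fed directly to a torchvision detection model.
# The model choice and the num_classes value below are assumptions, not
# something this file prescribes; adjust them to your setup.
def _sketch_torchvision_train_step():
    from torch.utils.data import DataLoader
    from torchvision.models.detection import fasterrcnn_resnet50_fpn

    dataset = LvisDataset(transform=lambda im, ann: (ToTensor()(im), ann))
    loader = DataLoader(
        dataset, batch_size=2, collate_fn=lambda batch: tuple(zip(*batch))
    )

    # Assumed: 1203 LVIS v1 categories plus background; verify this against
    # the annotation file you are actually using.
    model = fasterrcnn_resnet50_fpn(num_classes=1204)
    model.train()

    images, targets = next(iter(loader))
    loss_dict = model(list(images), list(targets))  # dict of detection losses
    return loss_dict

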
@dill.register(LvisDataset)
def checkpoint_LvisDataset(pickler, obj: LvisDataset):
    constructor_based_serialization(
        pickler,
        obj,
        LvisDataset,
        deduplicate=True,
        kwargs=dict(
            root=obj.root,
            train=obj.train,
            transform=obj.transform,
            loader=obj.loader,
            lvis_api=obj.lvis_api,
            img_ids=obj.img_ids,
        ),
    )


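# NOTE (illustrative sketch, not part of the upstream file): the
# @dill.register hook above installs a custom reducer, so dill re-creates an
# LvisDataset by calling its constructor with the recorded keyword arguments
# (deduplicate=True lets repeated objects be stored only once). A minimal
# round-trip, assuming the dataset files are already available on disk:
def _sketch_dill_roundtrip():
    dataset = LvisDataset(train=False, download=False)
    blob = dill.dumps(dataset)  # serialized via checkpoint_LvisDataset
    restored = dill.loads(blob)  # rebuilt by calling LvisDataset(**kwargs)
    assert len(restored) == len(dataset)
    return restored

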
class LVISImgEntry(TypedDict):
    id: int
    date_captured: str
    neg_category_ids: List[int]
    license: int
    height: int
    width: int
    flickr_url: str
    coco_url: str
    not_exhaustive_category_ids: List[int]


class LVISAnnotationEntry(TypedDict):
    id: int
    area: float
    segmentation: List[List[float]]
    image_id: int
    bbox: List[int]
    category_id: int


class LVISImgTargets(TypedDict):
    id: torch.Tensor
    area: torch.Tensor
    segmentation: List[List[List[float]]]
    image_id: torch.Tensor
    bbox: torch.Tensor
    category_id: torch.Tensor
    labels: torch.Tensor


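# NOTE (illustrative example, not part of the upstream file; all values are
# made up): LVISAnnotationEntry mirrors one element of the raw LVIS
# "annotations" list. Its "bbox" is stored as [x, y, width, height], which is
# why LvisDataset.__getitem__ converts it to [xmin, ymin, xmax, ymax].
_EXAMPLE_ANNOTATION: LVISAnnotationEntry = {
    "id": 1,
    "area": 7500.0,
    "segmentation": [[10.0, 20.0, 110.0, 20.0, 110.0, 95.0, 10.0, 95.0]],
    "image_id": 42,
    "bbox": [10, 20, 100, 75],  # x, y, width, height
    "category_id": 3,
}

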
class LVISDetectionTargets(Sequence[List[LVISImgTargets]]):
    def __init__(self, lvis_api: LVIS, img_ids: Optional[List[int]] = None):
        super(LVISDetectionTargets, self).__init__()
        self.lvis_api = lvis_api
        if img_ids is None:
            img_ids = list(sorted(lvis_api.get_img_ids()))

        self.img_ids = img_ids

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        annotation_ids = self.lvis_api.get_ann_ids(img_ids=[img_id])
        annotation_dicts: List[LVISAnnotationEntry] = self.lvis_api.load_anns(
            annotation_ids
        )

        n_annotations = len(annotation_dicts)

        category_tensor = torch.empty((n_annotations,), dtype=torch.long)
        target_dict: LVISImgTargets = {
            "bbox": torch.empty((n_annotations, 4), dtype=torch.float32),
            "category_id": category_tensor,
            "id": torch.empty((n_annotations,), dtype=torch.long),
            "area": torch.empty((n_annotations,), dtype=torch.float32),
            "image_id": torch.full((1,), img_id, dtype=torch.long),
            "segmentation": [],
            "labels": category_tensor,  # Alias of category_id
        }

        for ann_idx, annotation in enumerate(annotation_dicts):
            target_dict["bbox"][ann_idx] = torch.as_tensor(annotation["bbox"])
            target_dict["category_id"][ann_idx] = annotation["category_id"]
            target_dict["id"][ann_idx] = annotation["id"]
            target_dict["area"][ann_idx] = annotation["area"]
            target_dict["segmentation"].append(annotation["segmentation"])

        return target_dict


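# NOTE (illustrative sketch, not part of the upstream file):
# LVISDetectionTargets is a lazy sequence: annotations are fetched from the
# LVIS API only when an index is accessed, so targets can be inspected
# without loading any image. Assumes the annotation json has already been
# downloaded to the default dataset location.
def _sketch_inspect_targets_only():
    root = default_dataset_location("lvis")
    lvis_api = LVIS(str(root / "lvis_v1_val.json"))
    targets = LVISDetectionTargets(lvis_api)
    first = targets[0]
    print("bbox:", tuple(first["bbox"].shape))
    print("labels:", tuple(first["labels"].shape))
    return first

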
def _test_to_tensor(a, b):
    return ToTensor()(a), b


def _detection_collate_fn(batch):
    return tuple(zip(*batch))


def _plot_detection_sample(img: Image.Image, target):
    from matplotlib import patches
    import matplotlib.pyplot as plt

    plt.gca().imshow(img)
    for box in target["boxes"]:
        box = box.tolist()

        rect = patches.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        plt.gca().add_patch(rect)


if __name__ == "__main__":
    # this little example script can be used to visualize the first image
    # loaded from the dataset.
    from torch.utils.data.dataloader import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = LvisDataset(transform=_test_to_tensor)
    test_data = LvisDataset(transform=_test_to_tensor, train=False)
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))
    dataloader = DataLoader(train_data, batch_size=1, collate_fn=_detection_collate_fn)

    n_to_show = 5
    for instance_idx, batch_data in enumerate(dataloader):
        x, y = batch_data
        x = x[0]
        y = y[0]
        _plot_detection_sample(transforms.ToPILImage()(x), y)
        plt.show()
        print("X image shape", x.shape)
        print("N annotations:", len(y["boxes"]))
        if (instance_idx + 1) >= n_to_show:
            break

__all__ = [
    "LvisDataset",
    "LVISImgEntry",
    "LVISAnnotationEntry",
    "LVISImgTargets",
    "LVISDetectionTargets",
]
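
Usage sketch: the img_ids constructor argument makes it possible to build several dataset views over the same annotation file without re-parsing the JSON, for example to split the training images into disjoint subsets. A minimal sketch, assuming the data is already in the default location; the 50/50 split is an arbitrary illustration and subset_a / subset_b are hypothetical names:

full_train = LvisDataset(train=True, download=False)
ids = full_train.img_ids
half = len(ids) // 2

# Reuse the already-loaded LVIS API object so the annotation JSON is parsed once.
subset_a = LvisDataset(
    train=True, lvis_api=full_train.lvis_api, img_ids=ids[:half]
)
subset_b = LvisDataset(
    train=True, lvis_api=full_train.lvis_api, img_ids=ids[half:]
)

print(len(subset_a), len(subset_b), len(full_train))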