• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5725326611

pending completion
5725326611

push

github

web-flow
Merge pull request #1439 from lrzpellegrini/ffcv_support_pt2

FFCV support

500 of 806 new or added lines in 14 files covered. (62.03%)

1 existing line in 1 file now uncovered.

17477 of 23989 relevant lines covered (72.85%)

2.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.13
/tests/unit_tests_utils.py
1
import copy
4✔
2
import itertools
4✔
3
from os.path import expanduser
4✔
4

5
import os
4✔
6
import random
4✔
7
import torch
4✔
8
from PIL.Image import Image
4✔
9
from sklearn.datasets import make_blobs, make_classification
4✔
10
from sklearn.model_selection import train_test_split
4✔
11
import numpy as np
4✔
12
from torch.utils.data import TensorDataset, Dataset
4✔
13
from torch.utils.data.dataloader import DataLoader
4✔
14

15
from torchvision.datasets import MNIST
4✔
16
from torchvision.transforms import Compose, ToTensor
4✔
17

18
from avalanche.benchmarks import nc_benchmark
4✔
19
from avalanche.benchmarks.utils.detection_dataset import (
4✔
20
    make_detection_dataset,
21
)
22

23

24
# Set FAST_TEST=true in the environment to skip some expensive tests that are
# very unlikely to break unless their code is touched directly (e.g. datasets).
FAST_TEST = os.environ.get("FAST_TEST", "").lower() == "true"

# Set UPDATE_METRICS=true in the environment to regenerate the metric pickles
# that provide the ground truth for the metric tests. Needed whenever the
# metrics change (names, x values, y values, ...).
UPDATE_METRICS = os.environ.get("UPDATE_METRICS", "").lower() == "true"
38

39

40
def is_github_action():
    """Tell whether the code is executing inside a GitHub Actions job.

    Detection relies only on the presence of the ``GITHUB_ACTION``
    environment variable (its value is ignored). It lets callers avoid
    expensive operations, such as downloading data, in the CI pipeline.
    """
    return os.environ.get("GITHUB_ACTION") is not None
47

48

49
def common_setups():
    """Run the setup steps shared by the unit tests (currently none)."""
    # Disabled: adapt_dataset_urls()
    return None
52

53

54
def load_benchmark(use_task_labels=False, fast_test=True):
    """Return a class-incremental benchmark split into 5 experiences.

    With ``fast_test`` (the default) this is the small synthetic benchmark
    built by ``get_fast_benchmark``. Otherwise MNIST (stored in, and
    downloaded to if missing, ``~/.avalanche/data/mnist/``) is split into 5
    experiences with ``nc_benchmark`` using the fixed seed 1234.

    :param use_task_labels: forwarded as ``task_labels`` to the benchmark
        generator.
    :param fast_test: if True, use the synthetic benchmark instead of MNIST.
    """
    if fast_test:
        return get_fast_benchmark(use_task_labels)

    mnist_root = expanduser("~") + "/.avalanche/data/mnist/"
    # Train split first, then test split.
    mnist_train, mnist_test = (
        MNIST(
            root=mnist_root,
            train=is_train,
            download=True,
            transform=Compose([ToTensor()]),
        )
        for is_train in (True, False)
    )
    return nc_benchmark(
        mnist_train, mnist_test, 5, task_labels=use_task_labels, seed=1234
    )
80

81

82
def load_image_data():
    """Return the ``(train, test)`` MNIST datasets with a ``ToTensor`` transform.

    Both splits are stored in (and downloaded to, if missing)
    ``~/.avalanche/data/mnist/``.
    """
    mnist_root = expanduser("~") + "/.avalanche/data/mnist/"

    def _split(train):
        # Build one MNIST split with the shared settings.
        return MNIST(
            root=mnist_root,
            train=train,
            download=True,
            transform=Compose([ToTensor()]),
        )

    return _split(True), _split(False)
96

97

98
# Module-level cache filled by the first call to ``load_image_benchmark``.
image_data = None


def load_image_benchmark():
    """Return the MNIST training set (10 classes), without transforms.

    The dataset is created (and downloaded to ``~/.avalanche/data/mnist/``
    if missing) on the first call only; later calls return the cached
    ``image_data`` object.
    """
    global image_data

    if image_data is not None:
        return image_data

    image_data = MNIST(
        root=expanduser("~") + "/.avalanche/data/mnist/",
        train=True,
        download=True,
    )
    return image_data
112

113

114
def load_tensor_benchmark(n_samples=32, n_features=10):
    """Return a ``TensorDataset`` of random inputs paired with random targets.

    Inputs and targets are both float tensors of shape
    ``(n_samples, n_features)`` drawn with ``torch.rand`` (uniform in
    [0, 1)). This is not an image dataset and it has no class labels.

    :param n_samples: number of samples in the dataset (default 32).
    :param n_features: size of each input and of each target (default 10).
    :return: a ``TensorDataset`` yielding ``(x, y)`` pairs.
    """
    x = torch.rand(n_samples, n_features)
    y = torch.rand(n_samples, n_features)
    return TensorDataset(x, y)
119

120

121
def get_fast_benchmark(
    use_task_labels=False,
    shuffle=True,
    n_samples_per_class=100,
    n_classes=10,
    n_features=6,
    seed=None,
    train_transform=None,
    eval_transform=None,
):
    """Return a small synthetic class-incremental benchmark with 5 experiences.

    ``n_classes * n_samples_per_class`` samples with ``n_features`` features
    are generated with scikit-learn's ``make_classification``. They are split
    60/40 (stratified by class) into train and test sets and wrapped with
    ``nc_benchmark``.

    :param use_task_labels: forwarded to ``nc_benchmark`` as ``task_labels``.
    :param shuffle: forwarded to ``nc_benchmark``.
    :param n_samples_per_class: number of samples generated for each class.
    :param n_classes: number of classes.
    :param n_features: number of features per sample. At most 6 of them are
        informative. Note that ``make_classification`` also requires
        ``n_classes * 2 <= 2 ** n_informative`` (2 clusters per class by
        default), so very small values combined with many classes still fail.
    :param seed: used for data generation, for the train/test split and by
        ``nc_benchmark``.
    :param train_transform: forwarded to ``nc_benchmark``.
    :param eval_transform: forwarded to ``nc_benchmark``.
    """
    # make_classification rejects n_informative > n_features, so cap it.
    # For n_features >= 6 (including the default) this equals the previous
    # hard-coded value of 6.
    n_informative = min(6, n_features)
    X_np, y_np = make_classification(
        n_samples=n_classes * n_samples_per_class,
        n_classes=n_classes,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=0,
        random_state=seed,
    )

    X = torch.from_numpy(X_np).float()
    y = torch.from_numpy(y_np).long()

    train_X, test_X, train_y, test_y = train_test_split(
        X, y, train_size=0.6, shuffle=True, stratify=y, random_state=seed
    )

    train_dataset = TensorDataset(train_X, train_y)
    test_dataset = TensorDataset(test_X, test_y)
    my_nc_benchmark = nc_benchmark(
        train_dataset,
        test_dataset,
        5,
        task_labels=use_task_labels,
        shuffle=shuffle,
        train_transform=train_transform,
        eval_transform=eval_transform,
        seed=seed,
    )
    return my_nc_benchmark
160

161

162
class DummyImageDataset(Dataset):
    """Fake image-classification dataset made of placeholder PIL images.

    Each item is a ``(PIL.Image.Image(), target)`` pair, where the image is a
    default-constructed ``Image`` placeholder created on every access. The
    first ``n_classes`` targets are ``0 .. n_classes - 1``, so every class
    appears at least once. The remaining targets are drawn with
    ``random.randint``.
    """

    def __init__(self, n_elements=10000, n_classes=100):
        assert n_elements >= n_classes

        super().__init__()
        # One guaranteed sample per class...
        guaranteed = list(range(n_classes))
        # ...then random labels for the rest of the elements.
        extra = [
            random.randint(0, n_classes - 1)
            for _ in range(n_elements - n_classes)
        ]
        self.targets = guaranteed + extra

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        placeholder = Image()
        return placeholder, self.targets[index]
180

181

182
def load_experience_train_eval(experience, batch_size=32, num_workers=0):
    """Smoke-test an experience by loading one batch in train and eval mode.

    A ``DataLoader`` is built over ``experience.dataset.train()`` and the
    first batch is fetched and unpacked as ``(x, y, t)``. The same is then
    done for ``experience.dataset.eval()``. Nothing is returned. The call
    fails if a loader raises or if a batch does not unpack into exactly three
    elements.
    """

    def _fetch_first_batch(dataset):
        # Only the first batch is consumed; the unpacking checks its arity.
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        for _x, _y, _t in loader:
            return

    _fetch_first_batch(experience.dataset.train())
    _fetch_first_batch(experience.dataset.eval())
196

197

198
def get_device():
    """Return the device string the tests should run on: "cuda" or "cpu".

    GPU usage is opt-in: "cuda" is returned only when the ``USE_GPU``
    environment variable equals "true" (case-insensitive). The choice is
    printed for visibility in the test logs. CUDA availability is not
    checked here.
    """
    use_gpu = os.environ.get("USE_GPU", "").lower() == "true"
    print("Test on GPU:", use_gpu)
    return "cuda" if use_gpu else "cpu"
4✔
209

210

211
def set_deterministic_run(seed=0):
    """Seed every random number generator used by the tests.

    Seeds Python's ``random``, NumPy's global generator and PyTorch. When
    CUDA is available, the CUDA generator is seeded as well, and cuDNN is
    enabled with benchmarking off and deterministic mode on.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True
220

221

222
class _DummyDetectionDataset:
4✔
223
    """
4✔
224
    A dataset that makes a defensive copy of the
225
    targets before returning them.
226

227
    Alas, many detection transformations, including the
228
    ones in the torchvision repository, modify bounding boxes
229
    (and other elements) in place.
230
    Luckly, images seem to be never modified in place.
231
    """
232

233
    def __init__(self, images, targets):
4✔
234
        self.images = images
4✔
235
        self.targets = targets
4✔
236

237
    def __len__(self):
4✔
238
        return len(self.images)
4✔
239

240
    def __getitem__(self, index):
4✔
241
        return self.images[index], copy.deepcopy(self.targets[index])
4✔
242

243

244
def get_fast_detection_datasets(
    n_images=30,
    max_elements_per_image=10,
    n_samples_per_class=20,
    n_classes=10,
    seed=None,
    image_size=64,
    n_test_images=5,
):
    """Create a small synthetic object detection (train, test) dataset pair.

    Each image is an ``image_size`` x ``image_size`` RGB canvas with simple
    shapes drawn on it. Every class is mapped to a distinct (form, color)
    pair, with forms in {ellipse, rectangle, line, arc} and colors from
    matplotlib's Tableau palette. ``n_samples_per_class`` objects are
    generated for each class and spread as evenly as possible over the
    ``n_images`` images. ``n_test_images`` randomly chosen images form the
    test set; the remaining ones form the training set.

    Each target is a dict with ``boxes`` (float32, ``(n, 4)``, in
    left/top/right/bottom order), ``labels``, ``image_id``, ``area`` (the box
    area) and ``iscrowd`` (all zeros). Boxes are not clipped to the image.

    :param n_images: total number of images (train + test).
    :param max_elements_per_image: maximum number of objects per image; only
        used to validate the other arguments.
    :param n_samples_per_class: number of objects generated for each class.
    :param n_classes: number of classes; at most the number of available
        (form, color) pairs.
    :param seed: if not None, seeds the global ``numpy`` and ``random``
        generators and the blob placement, making the output reproducible.
    :param image_size: side, in pixels, of the square images.
    :param n_test_images: number of images assigned to the test set; must be
        positive and smaller than ``n_images``.
    :return: a ``(train_dataset, test_dataset)`` tuple built with
        ``make_detection_dataset`` (task label 0).
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    n_objects = n_samples_per_class * n_classes
    assert n_images * max_elements_per_image >= n_objects
    assert n_test_images < n_images
    assert n_test_images > 0

    # Spread the objects as evenly as possible over the images, then shuffle
    # which images receive the extra object.
    base_n_per_images = n_objects // n_images
    additional_elements = n_objects % n_images
    to_allocate = np.full(n_images, base_n_per_images)
    to_allocate[:additional_elements] += 1
    np.random.shuffle(to_allocate)
    # Start offset of each image's slice in ``classes_elements`` (prefix
    # sums computed once, instead of re-summing the prefix for every image).
    alloc_offsets = np.concatenate(([0], np.cumsum(to_allocate)[:-1]))

    classes_elements = np.repeat(np.arange(n_classes), n_samples_per_class)
    np.random.shuffle(classes_elements)

    # Optional dependencies, imported only when this helper is used.
    import matplotlib.colors as mcolors
    from PIL import Image as ImageApi
    from PIL import ImageDraw

    forms = ["ellipse", "rectangle", "line", "arc"]
    colors = list(mcolors.TABLEAU_COLORS.values())
    combs = list(itertools.product(forms, colors))
    random.shuffle(combs)
    # Each class needs its own (form, color) pair.
    assert n_classes <= len(combs)

    # With a seed, draw all blob centers from one dedicated generator so that
    # images get different (but reproducible) layouts. Passing the same
    # integer seed to every make_blobs call would give all images with the
    # same object count identical object positions. Without a seed, numpy's
    # global generator is used (random_state=None).
    blob_random_state = None if seed is None else np.random.RandomState(seed)

    generated_images = []
    generated_targets = []
    for img_idx in range(n_images):
        n_to_allocate = int(to_allocate[img_idx])
        base_alloc_idx = int(alloc_offsets[img_idx])
        classes_to_instantiate = classes_elements[
            base_alloc_idx : base_alloc_idx + n_to_allocate
        ]

        if n_to_allocate > 0:
            _, _, clusters = make_blobs(
                n_to_allocate,
                n_features=2,
                centers=n_to_allocate,
                center_box=(0, image_size - 1),
                random_state=blob_random_state,
                return_centers=True,
            )
        else:
            # make_blobs does not support zero centers: leave the image empty.
            clusters = np.empty((0, 2))

        im = ImageApi.new("RGB", (image_size, image_size))
        draw = ImageDraw.Draw(im)

        target = {
            "boxes": torch.zeros((n_to_allocate, 4), dtype=torch.float32),
            "labels": torch.zeros((n_to_allocate,), dtype=torch.long),
            "image_id": torch.full((1,), img_idx, dtype=torch.long),
            "area": torch.zeros((n_to_allocate,), dtype=torch.float32),
            "iscrowd": torch.zeros((n_to_allocate,), dtype=torch.long),
        }

        # ``obj_size`` is half the side of an object's bounding box, drawn
        # uniformly in image_size * 0.1 * [0.95, 1.05].
        obj_sizes = np.random.uniform(
            low=image_size * 0.1 * 0.95,
            high=image_size * 0.1 * 1.05,
            size=(n_to_allocate,),
        )
        for center_idx, center in enumerate(clusters):
            obj_size = float(obj_sizes[center_idx])
            class_to_gen = int(classes_to_instantiate[center_idx])

            class_form, class_color = combs[class_to_gen]

            left = center[0] - obj_size
            top = center[1] - obj_size
            right = center[0] + obj_size
            bottom = center[1] + obj_size
            ltrb = (left, top, right, bottom)
            if class_form == "ellipse":
                draw.ellipse(ltrb, fill=class_color)
            elif class_form == "rectangle":
                draw.rectangle(ltrb, fill=class_color)
            elif class_form == "line":
                draw.line(ltrb, fill=class_color, width=max(1, int(obj_size * 0.25)))
            elif class_form == "arc":
                draw.arc(ltrb, fill=class_color, start=45, end=200)
            else:
                raise RuntimeError("Unsupported form")

            target["boxes"][center_idx] = torch.as_tensor(ltrb)
            target["labels"][center_idx] = class_to_gen
            # Area of the stored box (2 * obj_size wide and tall), kept
            # consistent with ``boxes``.
            target["area"][center_idx] = (right - left) * (bottom - top)

        generated_images.append(np.array(im))
        generated_targets.append(target)
        im.close()

    test_indices = set(
        np.random.choice(n_images, n_test_images, replace=False).tolist()
    )
    train_images = [x for i, x in enumerate(generated_images) if i not in test_indices]
    test_images = [x for i, x in enumerate(generated_images) if i in test_indices]

    train_targets = [
        x for i, x in enumerate(generated_targets) if i not in test_indices
    ]
    test_targets = [x for i, x in enumerate(generated_targets) if i in test_indices]

    return make_detection_dataset(
        _DummyDetectionDataset(train_images, train_targets),
        targets=train_targets,
        task_labels=0,
    ), make_detection_dataset(
        _DummyDetectionDataset(test_images, test_targets),
        targets=test_targets,
        task_labels=0,
    )
363

364

365
# Names exported by a star import of this module. Other public names (e.g.
# FAST_TEST, is_github_action, load_image_benchmark, DummyImageDataset) are
# not listed here and must be imported explicitly.
__all__ = [
    "common_setups",
    "load_benchmark",
    "get_fast_benchmark",
    "load_experience_train_eval",
    "get_device",
    "set_deterministic_run",
    "get_fast_detection_datasets",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc