• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 8098020118

29 Feb 2024 02:57PM UTC coverage: 51.806% (-12.4%) from 64.17%
8098020118

push

github

web-flow
Update test-coverage-coveralls.yml

14756 of 28483 relevant lines covered (51.81%)

0.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

7.85
/tests/benchmarks/utils/test_avalanche_classification_dataset.py
1
import unittest
1✔
2

3
from os.path import expanduser
1✔
4

5
import avalanche
1✔
6
from avalanche.benchmarks.datasets import default_dataset_location
1✔
7
from avalanche.benchmarks.utils.data import AvalancheDataset
1✔
8
from avalanche.models import SimpleMLP
1✔
9
from torch.optim import SGD
1✔
10
from torch.nn import CrossEntropyLoss
1✔
11
from avalanche.training.supervised import Naive
1✔
12
from avalanche.benchmarks.scenarios.deprecated.generators import dataset_benchmark
1✔
13
import PIL
1✔
14
import torch
1✔
15
from PIL import ImageChops
1✔
16
from PIL.Image import Image
1✔
17
from torch import Tensor
1✔
18
from torch.utils.data import TensorDataset, Subset, ConcatDataset, DataLoader
1✔
19
from torchvision.datasets import MNIST
1✔
20
from torchvision.transforms import (
1✔
21
    ToTensor,
22
    RandomCrop,
23
    ToPILImage,
24
    Compose,
25
    Lambda,
26
)
27
from typing import List
1✔
28

29
from avalanche.benchmarks.scenarios.deprecated.generic_benchmark_creation import (
1✔
30
    create_generic_benchmark_from_tensor_lists,
31
)
32
from avalanche.benchmarks.utils import (
1✔
33
    _make_taskaware_classification_dataset,
34
    _taskaware_classification_subset,
35
    _concat_taskaware_classification_datasets,
36
    _make_taskaware_tensor_classification_dataset,
37
)
38
from avalanche.benchmarks.utils.utils import (
1✔
39
    concat_datasets,
40
)
41
from avalanche.training.utils import load_all_dataset
1✔
42
import random
1✔
43

44
import numpy as np
1✔
45

46
from avalanche.benchmarks.utils.flat_data import (
1✔
47
    _flatdata_depth,
48
)
49
from avalanche.benchmarks.utils.classification_dataset import (
1✔
50
    TaskAwareClassificationDataset,
51
    _concat_taskaware_classification_datasets_sequentially,
52
)
53
from tests.unit_tests_utils import load_image_data
1✔
54

55

56
def pil_images_equal(img_a, img_b):
    """Return True when two PIL images have identical pixel content.

    ``ImageChops.difference`` produces an all-zero image for equal inputs,
    in which case ``getbbox()`` returns ``None`` (no non-zero region).
    """
    delta = ImageChops.difference(img_a, img_b)
    return delta.getbbox() is None
60

61

62
def zero_if_label_2(img_tensor: Tensor, class_label):
    """Return ``(img_tensor, class_label)``, blanking the image in place
    when the label equals 2.

    Used in the tests as a 2-parameter transform to exercise
    multi-parameter transform support.
    """
    if int(class_label) == 2:
        # In-place fill: the very same tensor object is returned.
        img_tensor.fill_(0.0)
    return img_tensor, class_label
67

68

69
def get_mbatch(data, batch_size=5):
    """Return the first minibatch of ``data``, collated with the dataset's
    own ``collate_fn`` and without shuffling (deterministic order)."""
    loader = DataLoader(
        data, shuffle=False, batch_size=batch_size, collate_fn=data.collate_fn
    )
    return next(iter(loader))
74

75

76
class AvalancheDatasetTests(unittest.TestCase):
1✔
77
    def test_avalanche_dataset_multi_param_transform(self):
        """A MultiParamCompose mixing a 1-arg and a 2-arg transform must be
        applied correctly by a task-aware classification dataset."""
        dataset_mnist = MNIST(root=default_dataset_location("mnist"), download=True)

        # First sample labelled 2, and first sample with any other label.
        ref_instance2_idx = next(
            (idx for idx, (_, lbl) in enumerate(dataset_mnist) if lbl == 2), None
        )
        self.assertIsNotNone(ref_instance2_idx)

        ref_instance_idx = next(
            (idx for idx, (_, lbl) in enumerate(dataset_mnist) if lbl != 2), None
        )
        self.assertIsNotNone(ref_instance_idx)

        # Composing callables of different arities must emit a warning.
        with self.assertWarns(avalanche.benchmarks.utils.ComposeMaxParamsWarning):
            dataset_transform = avalanche.benchmarks.utils.MultiParamCompose(
                [ToTensor(), zero_if_label_2]
            )

        self.assertEqual(1, dataset_transform.min_params)
        self.assertEqual(2, dataset_transform.max_params)

        tgs = {"train": dataset_transform, "eval": dataset_transform}
        x, y = dataset_mnist[ref_instance_idx]
        dataset = _make_taskaware_classification_dataset(
            dataset_mnist, transform_groups=tgs
        )
        x2, y2, t2 = dataset[ref_instance_idx]

        # Non-class-2 sample: only ToTensor applies, default task label 0.
        self.assertIsInstance(x2, Tensor)
        self.assertIsInstance(y2, int)
        self.assertIsInstance(t2, int)
        self.assertEqual(0, t2)
        self.assertTrue(torch.equal(ToTensor()(x), x2))
        self.assertEqual(y, y2)

        # Class-2 sample: zero_if_label_2 must have blanked the image.
        x3, y3, _ = dataset[ref_instance2_idx]

        self.assertEqual(2, y3)
        self.assertIsInstance(x3, Tensor)
        self.assertEqual(0.0, torch.min(x3))
        self.assertEqual(0.0, torch.max(x3))
123

124
    def test_avalanche_dataset_tensor_task_labels(self):
        """Tensor task labels must be exposed per-sample and via task_set."""
        x = torch.rand(32, 10)
        y = torch.rand(32, 10)
        t = torch.ones(32)  # every sample belongs to task 1
        dataset = _make_taskaware_classification_dataset(
            TensorDataset(x, y), targets=1, task_labels=t
        )

        mb_x, mb_y, mb_t = get_mbatch(dataset, batch_size=32)

        self.assertIsInstance(mb_x, Tensor)
        self.assertIsInstance(mb_y, Tensor)
        self.assertIsInstance(mb_t, Tensor)
        self.assertTrue(torch.equal(x, mb_x))
        self.assertTrue(torch.equal(y, mb_y))
        self.assertTrue(torch.equal(t.to(int), mb_t))

        self.assertListEqual([1] * 32, list(dataset.targets_task_labels))

        # Regression test for #654
        self.assertEqual(1, len(dataset.task_set))

        subset_task1 = dataset.task_set[1]
        self.assertIsInstance(subset_task1, TaskAwareClassificationDataset)
        self.assertEqual(len(dataset), len(subset_task1))

        # Tasks that were never assigned must not be reachable.
        with self.assertRaises(KeyError):
            dataset.task_set[0]

        with self.assertRaises(KeyError):
            dataset.task_set[2]

        # Single-sample access keeps tensors but unwraps the task label.
        s_x, s_y, s_t = dataset[0]

        self.assertIsInstance(s_x, Tensor)
        self.assertIsInstance(s_y, Tensor)
        self.assertIsInstance(s_t, int)
162

163
    def test_avalanche_dataset_uniform_task_labels_simple_def(self):
        """A single scalar task label must be broadcast to all samples."""
        dataset_mnist = MNIST(
            root=expanduser("~") + "/.avalanche/data/mnist/", download=True
        )
        dataset = _make_taskaware_classification_dataset(
            dataset_mnist, transform=ToTensor(), task_labels=1
        )

        _, _, task_label = dataset[0]
        self.assertIsInstance(task_label, int)
        self.assertEqual(1, task_label)

        self.assertListEqual(
            [1] * len(dataset_mnist), list(dataset.targets_task_labels)
        )

        # The only populated task subset is task 1, covering everything.
        task1_subset = dataset.task_set[1]
        self.assertIsInstance(task1_subset, TaskAwareClassificationDataset)
        self.assertEqual(len(dataset), len(task1_subset))

        with self.assertRaises(KeyError):
            dataset.task_set[0]
185

186
    def test_avalanche_dataset_mixed_task_labels(self):
        """Per-sample (heterogeneous) task labels must round-trip correctly."""
        dataset_mnist = MNIST(
            root=expanduser("~") + "/.avalanche/data/mnist/", download=True
        )
        ref_x, ref_y = dataset_mnist[0]

        # One random task label in [0, 10] per sample.
        random_task_labels = [random.randint(0, 10) for _ in range(len(dataset_mnist))]
        dataset = _make_taskaware_classification_dataset(
            dataset_mnist, transform=ToTensor(), task_labels=random_task_labels
        )
        got_x, got_y, got_t = dataset[0]

        self.assertIsInstance(got_x, Tensor)
        self.assertIsInstance(got_y, int)
        self.assertIsInstance(got_t, int)
        self.assertEqual(random_task_labels[0], got_t)
        self.assertTrue(torch.equal(ToTensor()(ref_x), got_x))
        self.assertEqual(ref_y, got_y)

        self.assertListEqual(random_task_labels, list(dataset.targets_task_labels))

        # Each task subset contains exactly the samples with that label.
        u_labels, counts = np.unique(random_task_labels, return_counts=True)
        for task_label, count in zip(u_labels.tolist(), counts.tolist()):
            task_subset = dataset.task_set[task_label]
            self.assertIsInstance(task_subset, TaskAwareClassificationDataset)
            self.assertEqual(int(count), len(task_subset))
            self.assertListEqual(
                [task_label] * int(count), list(task_subset.targets_task_labels)
            )

        # Label 11 is never drawn (randint's upper bound 10 is inclusive).
        with self.assertRaises(KeyError):
            dataset.task_set[11]
218

219
    def test_avalanche_tensor_dataset_task_labels_train(self):
        """Naive training must iterate all experiences of a benchmark built
        from tensor datasets carrying per-sample task labels."""

        def make_datasets(count):
            # `count` random 10-sample datasets, labels in [0, 3),
            # task labels in [0, 5).
            return [
                _make_taskaware_tensor_classification_dataset(
                    torch.randn(10, 4),
                    torch.randint(0, 3, (10,)),
                    task_labels=torch.randint(0, 5, (10,)).tolist(),
                )
                for _ in range(count)
            ]

        benchmark = dataset_benchmark(
            train_datasets=make_datasets(3), test_datasets=make_datasets(3)
        )
        model = SimpleMLP(input_size=4, num_classes=3)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(),
            train_mb_size=5,
            train_epochs=1,
            eval_mb_size=5,
            device="cpu",
            evaluator=None,
        )

        trained_experiences = []
        for exp_idx, experience in enumerate(benchmark.train_stream):
            trained_experiences.append(exp_idx)
            cl_strategy.train(experience)
        # All 3 train experiences must have been visited.
        self.assertEqual(len(trained_experiences), 3)
253

254
    def test_avalanche_dataset_task_labels_inheritance(self):
        """Task labels must be inherited when a dataset wraps another one."""
        dataset_mnist = MNIST(
            root=expanduser("~") + "/.avalanche/data/mnist/", download=True
        )
        random_task_labels = [random.randint(0, 10) for _ in range(len(dataset_mnist))]
        parent = _make_taskaware_classification_dataset(
            dataset_mnist, transform=ToTensor(), task_labels=random_task_labels
        )
        # No explicit labels here: they must come from the wrapped dataset.
        child = _make_taskaware_classification_dataset(parent)

        for wrapped in (parent, child):
            _, _, task_label = wrapped[0]
            self.assertIsInstance(task_label, int)
            self.assertEqual(random_task_labels[0], task_label)
            self.assertListEqual(
                random_task_labels, list(wrapped.targets_task_labels)
            )
278

279
    def test_avalanche_dataset_tensor_dataset_input(self):
        """Plain TensorDatasets must be wrapped with a default task label 0."""
        train_x = torch.rand(500, 3, 28, 28)
        train_y = torch.zeros(500)
        test_x = torch.rand(200, 3, 28, 28)
        test_y = torch.ones(200)

        train_dataset = _make_taskaware_classification_dataset(
            TensorDataset(train_x, train_y)
        )
        test_dataset = _make_taskaware_classification_dataset(
            TensorDataset(test_x, test_y)
        )

        self.assertEqual(500, len(train_dataset))
        self.assertEqual(200, len(test_dataset))

        # Samples gain a trailing task label, defaulting to 0.
        x, y, t = train_dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertEqual(0, y)
        self.assertEqual(0, t)

        x2, y2, t2 = test_dataset[0]
        self.assertIsInstance(x2, Tensor)
        self.assertEqual(1, y2)
        self.assertEqual(0, t2)
302

303
    def test_avalanche_dataset_multiple_outputs_and_float_y(self):
        """Samples with extra fields (x, y, z) must keep all of them, with
        the task label appended last."""
        train_x = torch.rand(500, 3, 28, 28)
        train_y = torch.zeros(500)
        train_z = torch.ones(500)
        test_x = torch.rand(200, 3, 28, 28)
        test_y = torch.ones(200)
        test_z = torch.full((200,), 5)

        train_dataset = _make_taskaware_classification_dataset(
            TensorDataset(train_x, train_y, train_z)
        )
        test_dataset = _make_taskaware_classification_dataset(
            TensorDataset(test_x, test_y, test_z)
        )

        self.assertEqual(500, len(train_dataset))
        self.assertEqual(200, len(test_dataset))

        # 4-tuple per sample: (x, y, z, task).
        x, y, z, t = train_dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertEqual(0, y)
        self.assertEqual(1, z)
        self.assertEqual(0, t)

        x2, y2, z2, t2 = test_dataset[0]
        self.assertIsInstance(x2, Tensor)
        self.assertEqual(1, y2)
        self.assertEqual(5, z2)
        self.assertEqual(0, t2)
330

331
    def test_avalanche_concat_dataset_targets_val_to_idx(self):
        """targets.val_to_idx of a concatenated dataset must map each class
        value to a list of sample indices.

        Fix: the original drew labels with ``torch.randint``, so the
        ``assertEqual(10, ...)`` check could (rarely) fail when a class was
        never sampled. Deterministic labels cover all 10 classes by
        construction while exercising the same code path.
        """
        tensor_x = torch.rand(100, 3, 28, 28)
        tensor_x2 = torch.rand(11, 3, 28, 28)
        # Deterministic labels: both parts cycle through classes 0..9.
        tensor_y = torch.arange(100) % 10
        tensor_y2 = torch.arange(11) % 10
        dataset1 = TensorDataset(tensor_x, tensor_y)
        dataset2 = TensorDataset(tensor_x2, tensor_y2)
        concat = dataset1 + dataset2
        av_dataset = _make_taskaware_classification_dataset(concat)
        # Each class maps to a list of indices; all 10 classes are present.
        self.assertIsInstance(av_dataset.targets.val_to_idx[0], list)
        self.assertEqual(10, len(av_dataset.targets.val_to_idx))
342

343
    def test_avalanche_dataset_from_pytorch_subset(self):
        """torch.utils.data.Subset inputs must preserve data and targets."""
        tensor_x = torch.rand(500, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (500,))

        whole_dataset = TensorDataset(tensor_x, tensor_y)

        # Contiguous 400/100 train/test split of the same underlying data.
        train_dataset = _make_taskaware_classification_dataset(
            Subset(whole_dataset, indices=list(range(400)))
        )
        test_dataset = _make_taskaware_classification_dataset(
            Subset(whole_dataset, indices=list(range(400, 500)))
        )

        self.assertEqual(400, len(train_dataset))
        self.assertEqual(100, len(test_dataset))

        # First train sample maps to global index 0.
        x, y, t = train_dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x[0], x))
        self.assertTrue(torch.equal(tensor_y[0], y))
        self.assertEqual(0, t)

        self.assertTrue(
            torch.equal(torch.as_tensor(train_dataset.targets), tensor_y[:400])
        )

        # First test sample maps to global index 400.
        x2, y2, t2 = test_dataset[0]
        self.assertIsInstance(x2, Tensor)
        self.assertTrue(torch.equal(tensor_x[400], x2))
        self.assertTrue(torch.equal(tensor_y[400], y2))
        self.assertEqual(0, t2)

        self.assertTrue(
            torch.equal(torch.as_tensor(test_dataset.targets), tensor_y[400:])
        )
377

378
    def test_avalanche_dataset_from_pytorch_concat_dataset(self):
        """torch ConcatDataset inputs must be traversed with a global index
        and expose concatenated targets."""
        tensor_x = torch.rand(500, 3, 28, 28)
        tensor_x2 = torch.rand(300, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (500,))
        tensor_y2 = torch.randint(0, 100, (300,))

        dataset1 = TensorDataset(tensor_x, tensor_y)
        dataset2 = TensorDataset(tensor_x2, tensor_y2)

        av_dataset = _make_taskaware_classification_dataset(
            ConcatDataset((dataset1, dataset2))
        )

        self.assertEqual(500, len(dataset1))
        self.assertEqual(300, len(dataset2))

        # First sample of each member, reached through the concat view.
        for global_idx, exp_x, exp_y in (
            (0, tensor_x, tensor_y),
            (500, tensor_x2, tensor_y2),
        ):
            x, y, t = av_dataset[global_idx]
            self.assertIsInstance(x, Tensor)
            self.assertTrue(torch.equal(exp_x[0], x))
            self.assertTrue(torch.equal(exp_y[0], y))
            self.assertEqual(0, t)

        self.assertTrue(
            torch.equal(
                torch.as_tensor(av_dataset.targets),
                torch.cat((tensor_y, tensor_y2)),
            )
        )
412

413
    def test_avalanche_dataset_from_chained_pytorch_concat_dataset(self):
        """Nested torch ConcatDatasets must be flattened transparently."""
        tensor_x = torch.rand(500, 3, 28, 28)
        tensor_x2 = torch.rand(300, 3, 28, 28)
        tensor_x3 = torch.rand(200, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (500,))
        tensor_y2 = torch.randint(0, 100, (300,))
        tensor_y3 = torch.randint(0, 100, (200,))

        dataset1 = TensorDataset(tensor_x, tensor_y)
        dataset2 = TensorDataset(tensor_x2, tensor_y2)
        dataset3 = TensorDataset(tensor_x3, tensor_y3)

        # concat(concat(d1, d2), d3) -> 500 + 300 + 200 samples.
        inner_concat = ConcatDataset((dataset1, dataset2))
        outer_concat = ConcatDataset((inner_concat, dataset3))

        av_dataset = _make_taskaware_classification_dataset(outer_concat)

        self.assertEqual(500, len(dataset1))
        self.assertEqual(300, len(dataset2))

        # First sample of each member through the flattened dataset.
        for global_idx, exp_x, exp_y in (
            (0, tensor_x, tensor_y),
            (500, tensor_x2, tensor_y2),
            (800, tensor_x3, tensor_y3),
        ):
            x, y, t = av_dataset[global_idx]
            self.assertIsInstance(x, Tensor)
            self.assertTrue(torch.equal(exp_x[0], x))
            self.assertTrue(torch.equal(exp_y[0], y))
            self.assertEqual(0, t)

        self.assertTrue(
            torch.equal(
                torch.as_tensor(av_dataset.targets),
                torch.cat((tensor_y, tensor_y2, tensor_y3)),
            )
        )
457

458
    def test_avalanche_dataset_from_chained_pytorch_subsets(self):
        """Nested torch Subsets must compose their index mappings."""
        tensor_x = torch.rand(500, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (500,))

        whole_dataset = TensorDataset(tensor_x, tensor_y)

        # outer[i] -> inner[[5, 7, 0][i]] -> whole[400 + [5, 7, 0][i]]
        inner = Subset(whole_dataset, indices=list(range(400, 500)))
        outer = Subset(inner, indices=[5, 7, 0])

        dataset = _make_taskaware_classification_dataset(outer)

        self.assertEqual(3, len(dataset))

        x, y, t = dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x[405], x))
        self.assertTrue(torch.equal(tensor_y[405], y))
        self.assertEqual(0, t)

        self.assertTrue(
            torch.equal(
                torch.as_tensor(dataset.targets),
                torch.as_tensor([tensor_y[405], tensor_y[407], tensor_y[400]]),
            )
        )
483

484
    def test_avalanche_dataset_from_chained_pytorch_concat_subset_dataset(self):
        """A Subset over a ConcatDataset must resolve through both levels."""
        tensor_x = torch.rand(200, 3, 28, 28)
        tensor_x2 = torch.rand(100, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (200,))
        tensor_y2 = torch.randint(0, 100, (100,))

        dataset1 = TensorDataset(tensor_x, tensor_y)
        dataset2 = TensorDataset(tensor_x2, tensor_y2)

        # 1000 random picks (repetition allowed) over 300 concat samples.
        indices = [random.randint(0, 299) for _ in range(1000)]

        subset = Subset(ConcatDataset((dataset1, dataset2)), indices)
        av_dataset = _make_taskaware_classification_dataset(subset)

        self.assertEqual(200, len(dataset1))
        self.assertEqual(100, len(dataset2))
        self.assertEqual(1000, len(av_dataset))

        for pos, orig_idx in enumerate(indices):
            # Resolve the expected sample manually through the concat.
            if orig_idx < 200:
                expected_x, expected_y = dataset1[orig_idx]
            else:
                expected_x, expected_y = dataset2[orig_idx - 200]

            x, y, t = av_dataset[pos]
            self.assertIsInstance(x, Tensor)
            self.assertTrue(torch.equal(expected_x, x))
            self.assertTrue(torch.equal(expected_y, y))
            self.assertEqual(0, t)
            self.assertEqual(int(expected_y), int(av_dataset.targets[pos]))
517

518
    def test_avalanche_dataset_collate_fn(self):
        """A custom collate_fn must be used by the DataLoader and inherited
        by derived datasets, while single-sample access stays un-collated."""
        tensor_x = torch.rand(500, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (500,))
        tensor_z = torch.randint(0, 100, (500,))

        def my_collate_fn(batch):
            # Marker collate: y shifted by +1, z replaced by -1.
            xs = torch.stack([sample[0] for sample in batch], 0)
            ys = torch.tensor([sample[1] for sample in batch]) + 1
            zs = torch.tensor([-1 for _ in batch])
            ts = torch.tensor([sample[3] for sample in batch])
            return xs, ys, zs, ts

        whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
        dataset = _make_taskaware_classification_dataset(
            whole_dataset, collate_fn=my_collate_fn
        )

        # __getitem__ is not collated: raw values plus the task label.
        x, y, z, t = dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x[0], x))
        self.assertTrue(torch.equal(tensor_y[0], y))
        self.assertEqual(0, t)

        # Minibatches must show the collate markers.
        x2, y2, z2, t2 = get_mbatch(dataset)
        self.assertIsInstance(x2, Tensor)
        self.assertTrue(torch.equal(tensor_x[0:5], x2))
        self.assertTrue(torch.equal(tensor_y[0:5] + 1, y2))
        self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z2))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t2))

        # The collate function is inherited by a wrapping dataset.
        inherited = _make_taskaware_classification_dataset(dataset)

        x3, y3, z3, t3 = get_mbatch(inherited)
        self.assertIsInstance(x3, Tensor)
        self.assertTrue(torch.equal(tensor_x[0:5], x3))
        self.assertTrue(torch.equal(tensor_y[0:5] + 1, y3))
        self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z3))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t3))
556

557
    def test_avalanche_dataset_collate_fn_inheritance(self):
        """When a wrapped dataset already defines a collate_fn, a wrapper
        providing its own collate_fn must take precedence."""
        tensor_x = torch.rand(200, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (200,))
        tensor_z = torch.randint(0, 100, (200,))

        def my_collate_fn(batch):
            # Marker collate: y shifted by +1, z replaced by -1.
            xs = torch.stack([sample[0] for sample in batch], 0)
            ys = torch.tensor([sample[1] for sample in batch]) + 1
            zs = torch.tensor([-1 for _ in batch])
            ts = torch.tensor([sample[3] for sample in batch])
            return xs, ys, zs, ts

        def my_collate_fn2(batch):
            # Marker collate: y shifted by +2, z replaced by -2.
            xs = torch.stack([sample[0] for sample in batch], 0)
            ys = torch.tensor([sample[1] for sample in batch]) + 2
            zs = torch.tensor([-2 for _ in batch])
            ts = torch.tensor([sample[3] for sample in batch])
            return xs, ys, zs, ts

        whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
        dataset = _make_taskaware_classification_dataset(
            whole_dataset, collate_fn=my_collate_fn
        )
        inherited = _make_taskaware_classification_dataset(
            dataset, collate_fn=my_collate_fn2
        )  # Ok

        # The outermost collate function must win.
        x, y, z, t = get_mbatch(inherited)
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x[0:5], x))
        self.assertTrue(torch.equal(tensor_y[0:5] + 2, y))
        self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))

        # Wrapping without a custom collate function is always allowed.
        classification_dataset = _make_taskaware_classification_dataset(whole_dataset)
        _make_taskaware_classification_dataset(classification_dataset)
596

597
    def test_avalanche_concat_dataset_collate_fn_inheritance(self):
        """Collate functions must propagate through Avalanche concatenation:
        an explicit collate_fn on the concat wins over the members' own."""
        tensor_x = torch.rand(200, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (200,))
        tensor_z = torch.randint(0, 100, (200,))

        tensor_x2 = torch.rand(200, 3, 28, 28)
        tensor_y2 = torch.randint(0, 100, (200,))
        tensor_z2 = torch.randint(0, 100, (200,))

        def my_collate_fn(batch):
            # Marker collate: y shifted by +1, z replaced by -1.
            xs = torch.stack([sample[0] for sample in batch], 0)
            ys = torch.tensor([sample[1] for sample in batch]) + 1
            zs = torch.tensor([-1 for _ in batch])
            ts = torch.tensor([sample[3] for sample in batch])
            return xs, ys, zs, ts

        def my_collate_fn2(batch):
            # Marker collate: y shifted by +2, z replaced by -2.
            xs = torch.stack([sample[0] for sample in batch], 0)
            ys = torch.tensor([sample[1] for sample in batch]) + 2
            zs = torch.tensor([-2 for _ in batch])
            ts = torch.tensor([sample[3] for sample in batch])
            return xs, ys, zs, ts

        dataset1 = TensorDataset(tensor_x, tensor_y, tensor_z)
        dataset2 = _make_taskaware_tensor_classification_dataset(
            tensor_x2, tensor_y2, tensor_z2, collate_fn=my_collate_fn
        )
        concat = _concat_taskaware_classification_datasets(
            [dataset1, dataset2], collate_fn=my_collate_fn2
        )  # Ok

        # dataset2 keeps its own collate function...
        x, y, z, t = get_mbatch(dataset2)
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x2[0:5], x))
        self.assertTrue(torch.equal(tensor_y2[0:5] + 1, y))
        self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))

        # ...while the concat uses the explicitly provided one.
        x2, y2, z2, t2 = get_mbatch(concat)
        self.assertIsInstance(x2, Tensor)
        self.assertTrue(torch.equal(tensor_x[0:5], x2))
        self.assertTrue(torch.equal(tensor_y[0:5] + 2, y2))
        self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z2))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t2))

        dataset1_classification = _make_taskaware_tensor_classification_dataset(
            tensor_x, tensor_y, tensor_z
        )

        dataset2_segmentation = _make_taskaware_classification_dataset(dataset2)

        # NOTE(review): checks rejecting incompatible concatenations were
        # disabled in the original test and are intentionally not restored.

        # Concatenating compatible classification datasets must succeed,
        # regardless of operand order.
        ok_concat_classification = dataset1_classification + dataset2
        ok_concat_classification2 = dataset2 + dataset1_classification
659

660
    def test_avalanche_concat_dataset_recursion(self):
        """Concatenation must recurse through nested Avalanche/PyTorch
        datasets while honoring per-dataset task labels, explicit task-label
        overrides and (frozen) target transforms."""

        def gen_random_tensors(n):
            # (x, y, z) triple of random data for n samples.
            return (
                torch.rand(n, 3, 28, 28),
                torch.randint(0, 100, (n,)),
                torch.randint(0, 100, (n,)),
            )

        tensor_x, tensor_y, tensor_z = gen_random_tensors(200)
        tensor_x2, tensor_y2, tensor_z2 = gen_random_tensors(200)
        tensor_x3, tensor_y3, tensor_z3 = gen_random_tensors(200)
        tensor_x4, tensor_y4, tensor_z4 = gen_random_tensors(200)
        tensor_x5, tensor_y5, tensor_z5 = gen_random_tensors(200)
        tensor_x6, tensor_y6, tensor_z6 = gen_random_tensors(200)
        tensor_x7, tensor_y7, tensor_z7 = gen_random_tensors(200)

        dataset1 = TensorDataset(tensor_x, tensor_y, tensor_z)
        dataset2 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x2, tensor_y2, tensor_z2),
            targets=tensor_y2,
            task_labels=1,
        )
        dataset3 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x3, tensor_y3, tensor_z3),
            targets=tensor_y3,
            task_labels=2,
        )
        dataset4 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x4, tensor_y4, tensor_z4),
            targets=tensor_y4,
            task_labels=3,
        )
        dataset5 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x5, tensor_y5, tensor_z5),
            targets=tensor_y5,
            task_labels=4,
        )
        dataset6 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x6, tensor_y6, tensor_z6), targets=tensor_y6
        )
        dataset7 = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x7, tensor_y7, tensor_z7), targets=tensor_y7
        )

        # Exercises recursion over both PyTorch ConcatDataset and the
        # Avalanche concatenation helpers.
        concat = _concat_taskaware_classification_datasets([dataset1, dataset2])

        def transform_target_to_constant(ignored_target_value):
            return 101

        def transform_target_to_constant2(ignored_target_value):
            return 102

        # The explicit task_labels=5 *must* override the task labels already
        # set in dataset4 and dataset5.
        concat2 = _concat_taskaware_classification_datasets(
            [dataset4, dataset5],
            task_labels=5,
            target_transform=transform_target_to_constant,
        )

        concat3 = _concat_taskaware_classification_datasets(
            [dataset6, dataset7], target_transform=transform_target_to_constant2
        ).freeze_transforms()
        concat_uut = concat_datasets([concat, dataset3, concat2, concat3])

        self.assertEqual(400, len(concat))
        self.assertEqual(400, len(concat2))
        self.assertEqual(400, len(concat3))
        self.assertEqual(1400, len(concat_uut))

        # First sample of each member with untouched targets.
        for idx, exp_x, exp_y, exp_z, exp_t in (
            (0, tensor_x, tensor_y, tensor_z, 0),
            (200, tensor_x2, tensor_y2, tensor_z2, 1),
            (400, tensor_x3, tensor_y3, tensor_z3, 2),
        ):
            x, y, z, t = concat_uut[idx]
            self.assertTrue(torch.equal(x, exp_x[0]))
            self.assertTrue(torch.equal(y, exp_y[0]))
            self.assertTrue(torch.equal(z, exp_z[0]))
            self.assertEqual(exp_t, t)

        # Members whose targets were remapped to a constant value.
        for idx, exp_x, exp_z, exp_target, exp_t in (
            (600, tensor_x4, tensor_z4, 101, 5),
            (800, tensor_x5, tensor_z5, 101, 5),
            (1000, tensor_x6, tensor_z6, 102, 0),
            (1200, tensor_x7, tensor_z7, 102, 0),
        ):
            x, y, z, t = concat_uut[idx]
            self.assertTrue(torch.equal(x, exp_x[0]))
            self.assertEqual(exp_target, y)
            self.assertTrue(torch.equal(z, exp_z[0]))
            self.assertEqual(exp_t, t)
776

777
    def test_avalanche_pytorch_subset_recursion(self):
        """A taskaware subset over a plain PyTorch ``Subset`` must resolve
        indices recursively, and ``target_transform``/``task_labels`` passed
        to the outer subset must be applied on top of the inner one.
        """
        # Consistency fix: use the shared dataset-location helper instead of
        # a hard-coded expanduser path, like the rest of this file.
        dataset_mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        x, y = dataset_mnist[3000]
        x2, y2 = dataset_mnist[1010]

        subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])

        dataset = _taskaware_classification_subset(subset, indices=[0, 3])

        self.assertEqual(5, len(subset))
        self.assertEqual(2, len(dataset))

        # Indices [0, 3] of the inner subset map to MNIST indices 3000, 1010.
        x3, y3, t3 = dataset[0]
        x4, y4, t4 = dataset[1]
        self.assertTrue(pil_images_equal(x, x3))
        self.assertEqual(y, y3)
        self.assertEqual(0, t3)
        self.assertTrue(pil_images_equal(x2, x4))
        self.assertEqual(y2, y4)
        self.assertEqual(0, t4)
        self.assertFalse(pil_images_equal(x, x4))
        self.assertFalse(pil_images_equal(x2, x3))

        def transform_target_to_constant(ignored_target_value):
            return 101

        subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])

        dataset = _taskaware_classification_subset(
            subset,
            indices=[0, 3],
            target_transform=transform_target_to_constant,
            task_labels=5,
        )

        self.assertEqual(5, len(subset))
        self.assertEqual(2, len(dataset))

        # The outer target_transform and task label must be in effect.
        x5, y5, t5 = dataset[0]
        x6, y6, t6 = dataset[1]
        self.assertTrue(pil_images_equal(x, x5))
        self.assertEqual(101, y5)
        self.assertEqual(5, t5)
        self.assertTrue(pil_images_equal(x2, x6))
        self.assertEqual(101, y6)
        self.assertEqual(5, t6)
        self.assertFalse(pil_images_equal(x, x6))
        self.assertFalse(pil_images_equal(x2, x5))
    def test_avalanche_pytorch_subset_recursion_no_indices(self):
        """A taskaware subset built without explicit indices must expose the
        inner PyTorch ``Subset`` unchanged (same length, same order).
        """
        # Consistency fix: use the shared dataset-location helper instead of
        # a hard-coded expanduser path, like the rest of this file.
        dataset_mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        x, y = dataset_mnist[3000]
        x2, y2 = dataset_mnist[8]

        subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])

        dataset = _taskaware_classification_subset(subset)

        self.assertEqual(5, len(subset))
        self.assertEqual(5, len(dataset))

        # Without indices, position 0/1 map straight to MNIST 3000/8.
        x3, y3, t3 = dataset[0]
        x4, y4, t4 = dataset[1]
        self.assertTrue(pil_images_equal(x, x3))
        self.assertEqual(y, y3)
        self.assertTrue(pil_images_equal(x2, x4))
        self.assertEqual(y2, y4)
        self.assertFalse(pil_images_equal(x, x4))
        self.assertFalse(pil_images_equal(x2, x3))
    def test_avalanche_avalanche_subset_recursion_no_indices_transform(self):
        """Transforms of nested taskaware subsets must compose: the inner
        ``transform``/``target_transform`` run first, then the outer ones.
        """
        # Consistency fix: use the shared dataset-location helper instead of
        # a hard-coded expanduser path, like the rest of this file.
        dataset_mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        x, y = dataset_mnist[3000]
        x2, y2 = dataset_mnist[8]

        def transform_target_to_constant(ignored_target_value):
            return 101

        def transform_target_plus_one(target_value):
            return target_value + 1

        subset = _taskaware_classification_subset(
            dataset_mnist,
            indices=[3000, 8, 4, 1010, 12],
            transform=ToTensor(),
            target_transform=transform_target_to_constant,
        )

        dataset = _taskaware_classification_subset(
            subset, target_transform=transform_target_plus_one
        )

        self.assertEqual(5, len(subset))
        self.assertEqual(5, len(dataset))

        x3, y3, t3 = dataset[0]
        x4, y4, t4 = dataset[1]
        self.assertIsInstance(x3, Tensor)
        self.assertIsInstance(x4, Tensor)
        self.assertTrue(torch.equal(ToTensor()(x), x3))
        # 101 (inner constant) + 1 (outer transform) = 102.
        self.assertEqual(102, y3)
        self.assertTrue(torch.equal(ToTensor()(x2), x4))
        self.assertEqual(102, y4)
        self.assertFalse(torch.equal(ToTensor()(x), x4))
        self.assertFalse(torch.equal(ToTensor()(x2), x3))
    def test_avalanche_avalanche_subset_recursion_transform(self):
        """``target_transform``s of nested subsets compose: the inner one
        maps every target to 101 and the outer one adds 2, giving 103.
        """
        dataset_mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        x, y = dataset_mnist[3000]
        x2, y2 = dataset_mnist[1010]

        def transform_target_to_constant(ignored_target_value):
            return 101

        # Renamed from "transform_target_plus_one": the function adds 2, and
        # the sibling frozen-transform test already uses the correct name.
        def transform_target_plus_two(target_value):
            return target_value + 2

        subset = _taskaware_classification_subset(
            dataset_mnist,
            indices=[3000, 8, 4, 1010, 12],
            target_transform=transform_target_to_constant,
        )

        dataset = _taskaware_classification_subset(
            subset,
            indices=[0, 3, 1],
            target_transform=transform_target_plus_two,
        )

        self.assertEqual(5, len(subset))
        self.assertEqual(3, len(dataset))

        x3, y3, t3 = dataset[0]
        x4, y4, t4 = dataset[1]

        self.assertTrue(pil_images_equal(x, x3))
        self.assertEqual(103, y3)
        self.assertTrue(pil_images_equal(x2, x4))
        self.assertEqual(103, y4)
        self.assertFalse(pil_images_equal(x, x4))
        self.assertFalse(pil_images_equal(x2, x3))
    def test_avalanche_avalanche_subset_recursion_frozen_transform(self):
        """Frozen transforms must keep applying even after the current
        transform group of a derived dataset is replaced.
        """
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[3000]
        img_b, label_b = mnist[1010]

        def transform_target_to_constant(ignored_target_value):
            return 101

        def transform_target_plus_two(target_value):
            return target_value + 2

        # Inner subset: map every target to 101, then freeze that transform.
        frozen_subset = _taskaware_classification_subset(
            mnist,
            indices=[3000, 8, 4, 1010, 12],
            target_transform=transform_target_to_constant,
        ).freeze_transforms()

        child = _taskaware_classification_subset(
            frozen_subset,
            indices=[0, 3, 1],
            target_transform=transform_target_plus_two,
        )

        self.assertEqual(5, len(frozen_subset))
        self.assertEqual(3, len(child))

        out_a = child[0]
        out_b = child[1]

        # 101 (frozen inner transform) + 2 (outer transform) = 103.
        self.assertTrue(pil_images_equal(img_a, out_a[0]))
        self.assertEqual(103, out_a[1])
        self.assertTrue(pil_images_equal(img_b, out_b[0]))
        self.assertEqual(103, out_b[1])
        self.assertFalse(pil_images_equal(img_a, out_b[0]))
        self.assertFalse(pil_images_equal(img_b, out_a[0]))

        child = _taskaware_classification_subset(
            frozen_subset,
            indices=[0, 3, 1],
            target_transform=transform_target_plus_two,
        ).replace_current_transform_group(None)

        out_a = child[0]
        out_b = child[1]

        # Dropping the outer group leaves only the frozen constant (101).
        self.assertTrue(pil_images_equal(img_a, out_a[0]))
        self.assertEqual(101, out_a[1])
        self.assertTrue(pil_images_equal(img_b, out_b[0]))
        self.assertEqual(101, out_b[1])
        self.assertFalse(pil_images_equal(img_a, out_b[0]))
        self.assertFalse(pil_images_equal(img_b, out_a[0]))
    def test_avalanche_avalanche_subset_recursion_sub_class_mapping(self):
        """A ``class_mapping`` applied on the *inner* subset must be visible
        through an outer subset taken on top of it.
        """
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[3000]
        img_b, label_b = mnist[1010]

        mapping = list(range(10))
        random.shuffle(mapping)

        inner = _taskaware_classification_subset(
            mnist,
            indices=[3000, 8, 4, 1010, 12],
            class_mapping=mapping,
        )

        outer = _taskaware_classification_subset(inner, indices=[0, 3, 1])

        self.assertEqual(5, len(inner))
        self.assertEqual(3, len(outer))

        img_out_a, label_out_a, _ = outer[0]
        img_out_b, label_out_b, _ = outer[1]

        self.assertTrue(pil_images_equal(img_a, img_out_a))
        self.assertEqual(mapping[label_a], label_out_a)
        self.assertTrue(pil_images_equal(img_b, img_out_b))
        self.assertEqual(mapping[label_b], label_out_b)
        self.assertFalse(pil_images_equal(img_a, img_out_b))
        self.assertFalse(pil_images_equal(img_b, img_out_a))
    def test_avalanche_avalanche_subset_recursion_up_class_mapping(self):
        """A ``class_mapping`` applied on the *outer* subset must remap the
        plain targets coming from the inner subset.
        """
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[3000]
        img_b, label_b = mnist[1010]

        mapping = list(range(10))
        random.shuffle(mapping)

        inner = _taskaware_classification_subset(
            mnist, indices=[3000, 8, 4, 1010, 12]
        )

        outer = _taskaware_classification_subset(
            inner, indices=[0, 3, 1], class_mapping=mapping
        )

        self.assertEqual(5, len(inner))
        self.assertEqual(3, len(outer))

        img_out_a, label_out_a, _ = outer[0]
        img_out_b, label_out_b, _ = outer[1]

        self.assertTrue(pil_images_equal(img_a, img_out_a))
        self.assertEqual(mapping[label_a], label_out_a)
        self.assertTrue(pil_images_equal(img_b, img_out_b))
        self.assertEqual(mapping[label_b], label_out_b)
        self.assertFalse(pil_images_equal(img_a, img_out_b))
        self.assertFalse(pil_images_equal(img_b, img_out_a))
    def test_avalanche_avalanche_subset_recursion_mix_class_mapping(self):
        """Two stacked ``class_mapping``s must compose: outer(inner(target))."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[3000]
        img_b, label_b = mnist[1010]

        inner_mapping = list(range(10))
        outer_mapping = list(range(10))
        random.shuffle(inner_mapping)
        random.shuffle(outer_mapping)

        inner = _taskaware_classification_subset(
            mnist,
            indices=[3000, 8, 4, 1010, 12],
            class_mapping=inner_mapping,
        )

        outer = _taskaware_classification_subset(
            inner, indices=[0, 3, 1], class_mapping=outer_mapping
        )

        self.assertEqual(5, len(inner))
        self.assertEqual(3, len(outer))

        img_out_a, label_out_a, _ = outer[0]
        img_out_b, label_out_b, _ = outer[1]

        self.assertTrue(pil_images_equal(img_a, img_out_a))
        self.assertEqual(outer_mapping[inner_mapping[label_a]], label_out_a)
        self.assertTrue(pil_images_equal(img_b, img_out_b))
        self.assertEqual(outer_mapping[inner_mapping[label_b]], label_out_b)
        self.assertFalse(pil_images_equal(img_a, img_out_b))
        self.assertFalse(pil_images_equal(img_b, img_out_a))
    def test_avalanche_avalanche_subset_concat_stack_overflow(self):
        """Build a 500-deep subset/concat chain and verify that (a) indices
        are resolved correctly at every level and (b) the internal dataset
        tree is flattened instead of growing with the nesting depth.
        """
        d_sz = 4
        tensor_x = torch.rand(d_sz, 2)
        tensor_y = torch.randint(0, 7, (d_sz,))
        tensor_t = torch.randint(0, 7, (d_sz,))
        dataset = _make_taskaware_classification_dataset(
            TensorDataset(tensor_x, tensor_y),
            targets=tensor_y,
            task_labels=tensor_t,
        )
        n_steps = 500

        # One random permutation per nesting step.
        permutations: List[List[int]] = [
            random.sample(range(d_sz), d_sz) for _ in range(n_steps)
        ]

        # Expected index order of each leaf after the permutations are
        # applied cumulatively; reversed so index 0 is the deepest leaf.
        indices = list(range(d_sz))
        expected_indices: List[List[int]] = [list(indices)]
        for perm in permutations:
            indices = [indices[i] for i in perm]
            expected_indices.append(indices)
        expected_indices.reverse()

        # Iteratively permute the current dataset and prepend the permuted
        # view: curr = permuted_view + curr.
        curr_dataset = dataset
        for step in range(n_steps):
            offset = (n_steps - 1) - step
            permuted_view = _taskaware_classification_subset(
                curr_dataset, indices=permutations[step]
            )
            curr_dataset = permuted_view.concat(curr_dataset)

            # Regression test for #616 (second bug)
            # https://github.com/ContinualAI/avalanche/issues/616#issuecomment-848852287
            all_targets = torch.tensor(curr_dataset.targets)
            self.assertTrue(torch.equal(tensor_y, all_targets[-d_sz:]))

            curr_targets = torch.tensor(list(curr_dataset.targets))
            for inner in range(step + 1):
                # curr_dataset is the concat of step+1 permuted leaves plus
                # the pristine original: check each permuted leaf.
                leaf_range = range(inner * d_sz, (inner + 1) * d_sz)
                permuted = expected_indices[inner + offset]
                self.assertTrue(
                    torch.equal(tensor_y[permuted], curr_targets[leaf_range])
                )

            self.assertTrue(torch.equal(tensor_y, curr_targets[-d_sz:]))

        self.assertEqual(d_sz * n_steps + d_sz, len(curr_dataset))

        def gather(ds, idxs):
            # Stack the (x, y, t) triplets for the given indices.
            xs, ys, ts = [], [], []
            for i in idxs:
                xi, yi, ti = ds[i]
                xs.append(xi)
                ys.append(yi)
                ts.append(ti)
            return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.tensor(ts)

        for step in range(n_steps):
            leaf_range = range(step * d_sz, (step + 1) * d_sz)
            permuted = expected_indices[step]

            x_leaf, y_leaf, t_leaf = gather(curr_dataset, leaf_range)
            self.assertTrue(torch.equal(tensor_x[permuted], x_leaf))
            self.assertTrue(torch.equal(tensor_y[permuted], y_leaf))
            self.assertTrue(torch.equal(tensor_t[permuted], t_leaf))

            trg_leaf = torch.tensor(curr_dataset.targets)[leaf_range]
            self.assertTrue(torch.equal(tensor_y[permuted], trg_leaf))

        # The last d_sz elements are the original, unpermuted dataset.
        tail_idxs = list(range(d_sz * n_steps, len(curr_dataset)))
        x_tail, y_tail, t_tail = gather(curr_dataset, tail_idxs)
        self.assertTrue(torch.equal(tensor_x, x_tail))
        self.assertTrue(torch.equal(tensor_y, y_tail))
        self.assertTrue(torch.equal(tensor_t, t_tail))

        trg_tail = torch.tensor(curr_dataset.targets)[d_sz * n_steps :]
        self.assertTrue(torch.equal(tensor_y, trg_tail))

        # If you broke this test it means that dataset merging is not working
        # anymore. you are probably doing something that disable merging
        # (passing custom transforms?)
        # Good luck...
        assert _flatdata_depth(curr_dataset) <= 3
    def test_avalanche_concat_classification_datasets_sequentially(self):
        """After sequential concatenation, the union of the per-experience
        class lists must equal the set of classes of the merged train set.
        """

        def make_dataset(low, high):
            # 20 random patterns with labels drawn from [low, high).
            return _make_taskaware_classification_dataset(
                TensorDataset(torch.randn(20, 10), torch.randint(low, high, (20,)))
            )

        class_ranges = [(0, 2), (2, 4), (4, 6), (0, 2)]
        train = [make_dataset(lo, hi) for lo, hi in class_ranges]
        test = [make_dataset(lo, hi) for lo, hi in class_ranges]

        (
            final_train,
            _,
            classes,
        ) = _concat_taskaware_classification_datasets_sequentially(train, test)

        # Flatten the per-experience class lists into a single list.
        classes_all = [c for class_list in classes for c in class_list]

        # Classes actually present in the concatenated training set.
        target_classes = list(set(map(int, final_train.targets)))

        self.assertEqual(classes_all, target_classes)
class TransformationSubsetTests(unittest.TestCase):
    """Tests for ``_taskaware_classification_subset``: transforms, indices,
    class mappings, task labels and collate_fn inheritance."""

    def test_avalanche_subset_transform(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img, label = mnist[0]
        subset = _taskaware_classification_subset(mnist, transform=ToTensor())

        out_x, out_y, out_t = subset[0]
        self.assertIsInstance(out_x, Tensor)
        self.assertIsInstance(out_y, int)
        self.assertIsInstance(out_t, int)
        self.assertTrue(torch.equal(ToTensor()(img), out_x))
        self.assertEqual(label, out_y)
        # No task labels given: default task is 0.
        self.assertEqual(0, out_t)

    def test_avalanche_subset_composition(self):
        mnist = MNIST(
            root=default_dataset_location("mnist"),
            download=True,
            transform=RandomCrop(16),
        )
        img, label = mnist[0]
        self.assertIsInstance(img, Image)
        self.assertEqual([img.width, img.height], [16, 16])
        self.assertIsInstance(label, int)

        # The subset transform composes on top of the dataset's RandomCrop.
        subset = _taskaware_classification_subset(
            mnist,
            transform=ToTensor(),
            target_transform=lambda target: -1,
        )

        out_x, out_y, out_t = subset[0]
        self.assertIsInstance(out_x, Tensor)
        self.assertEqual(out_x.shape, (1, 16, 16))
        self.assertIsInstance(out_y, int)
        self.assertIsInstance(out_t, int)
        self.assertEqual(out_y, -1)
        self.assertEqual(0, out_t)

    def test_avalanche_subset_indices(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[1000]
        img_b, label_b = mnist[1007]

        subset = _taskaware_classification_subset(mnist, indices=[1000, 1007])

        out_a = subset[0]
        out_b = subset[1]
        self.assertTrue(pil_images_equal(img_a, out_a[0]))
        self.assertEqual(label_a, out_a[1])
        self.assertTrue(pil_images_equal(img_b, out_b[0]))
        self.assertEqual(label_b, out_b[1])
        self.assertFalse(pil_images_equal(img_a, out_b[0]))
        self.assertFalse(pil_images_equal(img_b, out_a[0]))

    def test_avalanche_subset_mapping(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        _, original_label = mnist[1000]

        # Swap the original label with a randomly chosen different class.
        mapping = list(range(10))
        candidates = [c for c in mapping if c != original_label]
        swapped_label = random.choice(candidates)
        mapping[original_label] = swapped_label
        mapping[swapped_label] = original_label

        subset = _taskaware_classification_subset(mnist, class_mapping=mapping)

        _, mapped_label, _ = subset[1000]
        self.assertEqual(mapped_label, swapped_label)

    def test_avalanche_subset_uniform_task_labels(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[1000]
        img_b, label_b = mnist[1007]

        # Case 1: one task label per pattern of the *original* dataset.
        subset = _taskaware_classification_subset(
            mnist,
            indices=[1000, 1007],
            task_labels=[1] * len(mnist),
        )

        _, out_label_a, out_task_a = subset[0]
        _, out_label_b, out_task_b = subset[1]
        self.assertEqual(label_a, out_label_a)
        self.assertEqual(1, out_task_a)
        self.assertEqual(label_b, out_label_b)
        self.assertEqual(1, out_task_b)

        # Case 2: one task label per *selected* index.
        subset = _taskaware_classification_subset(
            mnist, indices=[1000, 1007], task_labels=[1, 1]
        )

        _, out_label_a, out_task_a = subset[0]
        _, out_label_b, out_task_b = subset[1]
        self.assertEqual(label_a, out_label_a)
        self.assertEqual(1, out_task_a)
        self.assertEqual(label_b, out_label_b)
        self.assertEqual(1, out_task_b)

    def test_avalanche_subset_mixed_task_labels(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        img_a, label_a = mnist[1000]
        img_b, label_b = mnist[1007]

        full_task_labels = [1] * len(mnist)
        full_task_labels[1000] = 2
        # Case 1: one task label per pattern of the *original* dataset.
        subset = _taskaware_classification_subset(
            mnist, indices=[1000, 1007], task_labels=full_task_labels
        )

        _, out_label_a, out_task_a = subset[0]
        _, out_label_b, out_task_b = subset[1]
        self.assertEqual(label_a, out_label_a)
        self.assertEqual(2, out_task_a)
        self.assertEqual(label_b, out_label_b)
        self.assertEqual(1, out_task_b)

        # Case 2: one task label per *selected* index.
        subset = _taskaware_classification_subset(
            mnist, indices=[1000, 1007], task_labels=[3, 5]
        )

        _, out_label_a, out_task_a = subset[0]
        _, out_label_b, out_task_b = subset[1]
        self.assertEqual(label_a, out_label_a)
        self.assertEqual(3, out_task_a)
        self.assertEqual(label_b, out_label_b)
        self.assertEqual(5, out_task_b)

    def test_avalanche_subset_task_labels_inheritance(self):
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        random_task_labels = [random.randint(0, 10) for _ in range(len(mnist))]
        parent = _make_taskaware_classification_dataset(
            mnist, transform=ToTensor(), task_labels=random_task_labels
        )

        child = _taskaware_classification_subset(parent, indices=[1000, 1007])

        _, _, parent_task_a = parent[1000]
        _, _, parent_task_b = parent[1007]
        _, _, child_task_a = child[0]
        _, _, child_task_b = child[1]

        # The child must inherit the parent's per-pattern task labels.
        self.assertEqual(random_task_labels[1000], parent_task_a)
        self.assertEqual(random_task_labels[1007], parent_task_b)
        self.assertEqual(random_task_labels[1000], child_task_a)
        self.assertEqual(random_task_labels[1007], child_task_b)

        self.assertListEqual(random_task_labels, list(parent.targets_task_labels))

        self.assertListEqual(
            [random_task_labels[1000], random_task_labels[1007]],
            list(child.targets_task_labels),
        )

    def test_avalanche_subset_collate_fn_inheritance(self):
        tensor_x = torch.rand(200, 3, 28, 28)
        tensor_y = torch.randint(0, 100, (200,))
        tensor_z = torch.randint(0, 100, (200,))

        def make_collate_fn(y_offset, z_fill):
            # Build a collate_fn that shifts labels by y_offset and replaces
            # the z field with the constant z_fill.
            def collate(patterns):
                x_values = torch.stack([pat[0] for pat in patterns], 0)
                y_values = torch.tensor([pat[1] for pat in patterns]) + y_offset
                z_values = torch.tensor([z_fill for _ in patterns])
                t_values = torch.tensor([pat[3] for pat in patterns])
                return x_values, y_values, z_values, t_values

            return collate

        whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
        dataset = _make_taskaware_classification_dataset(
            whole_dataset, collate_fn=make_collate_fn(1, -1)
        )
        # The subset's own collate_fn must win over the inherited one.
        inherited = _taskaware_classification_subset(
            dataset, indices=list(range(5, 150)), collate_fn=make_collate_fn(2, -2)
        )  # Ok

        x, y, z, t = get_mbatch(inherited, batch_size=5)
        self.assertIsInstance(x, Tensor)
        self.assertTrue(torch.equal(tensor_x[5:10], x))
        self.assertTrue(torch.equal(tensor_y[5:10] + 2, y))
        self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z))
        self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))

        classification_dataset = _make_taskaware_classification_dataset(whole_dataset)

        # Building a subset without a collate_fn must not raise.
        ok_inherited_classification = _taskaware_classification_subset(
            classification_dataset, indices=list(range(5, 150))
        )
class TransformationTensorDatasetTests(unittest.TestCase):
    """Tests for benchmarks built from in-memory tensor lists."""

    def test_tensor_dataset_helper_tensor_y(self):
        # 5 train/test experiences, each a [x, y] pair of tensors.
        train_exps = [
            [torch.rand(50, 32, 32), torch.randint(0, 100, (50,))] for _ in range(5)
        ]
        test_exps = [
            [torch.rand(23, 32, 32), torch.randint(0, 100, (23,))] for _ in range(5)
        ]

        benchmark = create_generic_benchmark_from_tensor_lists(
            train_tensors=train_exps,
            test_tensors=test_exps,
            task_labels=[0] * 5,
        )

        self.assertEqual(5, len(benchmark.train_stream))
        self.assertEqual(5, len(benchmark.test_stream))
        self.assertEqual(5, benchmark.n_experiences)

        for exp_id in range(benchmark.n_experiences):
            # Round-trip: the benchmark must return the original tensors.
            train_x, train_y, _ = load_all_dataset(
                benchmark.train_stream[exp_id].dataset
            )
            test_x, test_y, _ = load_all_dataset(
                benchmark.test_stream[exp_id].dataset
            )

            self.assertTrue(torch.all(torch.eq(train_exps[exp_id][0], train_x)))
            self.assertTrue(torch.all(torch.eq(train_exps[exp_id][1], train_y)))
            self.assertSequenceEqual(
                train_exps[exp_id][1].tolist(),
                benchmark.train_stream[exp_id].dataset.targets,
            )
            self.assertEqual(0, benchmark.train_stream[exp_id].task_label)

            self.assertTrue(torch.all(torch.eq(test_exps[exp_id][0], test_x)))
            self.assertTrue(torch.all(torch.eq(test_exps[exp_id][1], test_y)))
            self.assertSequenceEqual(
                test_exps[exp_id][1].tolist(),
                benchmark.test_stream[exp_id].dataset.targets,
            )
            self.assertEqual(0, benchmark.test_stream[exp_id].task_label)

    def test_tensor_dataset_helper_list_y(self):
        # Same as above, but experiences are (x, y) tuples and targets are
        # compared as sequences.
        train_exps = [
            (torch.rand(50, 32, 32), torch.randint(0, 100, (50,))) for _ in range(5)
        ]
        test_exps = [
            (torch.rand(23, 32, 32), torch.randint(0, 100, (23,))) for _ in range(5)
        ]

        benchmark = create_generic_benchmark_from_tensor_lists(
            train_tensors=train_exps,
            test_tensors=test_exps,
            task_labels=[0] * 5,
        )

        self.assertEqual(5, len(benchmark.train_stream))
        self.assertEqual(5, len(benchmark.test_stream))
        self.assertEqual(5, benchmark.n_experiences)

        for exp_id in range(benchmark.n_experiences):
            train_x, train_y, _ = load_all_dataset(
                benchmark.train_stream[exp_id].dataset
            )
            test_x, test_y, _ = load_all_dataset(
                benchmark.test_stream[exp_id].dataset
            )

            self.assertTrue(torch.all(torch.eq(train_exps[exp_id][0], train_x)))
            self.assertSequenceEqual(train_exps[exp_id][1], train_y.tolist())
            self.assertSequenceEqual(
                train_exps[exp_id][1],
                benchmark.train_stream[exp_id].dataset.targets,
            )
            self.assertEqual(0, benchmark.train_stream[exp_id].task_label)

            self.assertTrue(torch.all(torch.eq(test_exps[exp_id][0], test_x)))
            self.assertSequenceEqual(test_exps[exp_id][1], test_y.tolist())
            self.assertSequenceEqual(
                test_exps[exp_id][1],
                benchmark.test_stream[exp_id].dataset.targets,
            )
            self.assertEqual(0, benchmark.test_stream[exp_id].task_label)
class AvalancheDatasetTransformOpsTests(unittest.TestCase):
    """Tests for transform-group handling (train/eval groups, freezing,
    replacing, and inheritance) on taskaware classification datasets."""

    def test_avalanche_inherit_groups(self):
        """Transform groups must be inherited by wrapped datasets, subsets
        and concatenations (regression tests for issues #565 and #566)."""
        original_dataset = MNIST(root=default_dataset_location("mnist"), download=True)

        def plus_one_target(target):
            # Target transform used only by the 'eval' group, so that the
            # active group can be detected from the returned label.
            return target + 1

        transform_groups = dict(train=(ToTensor(), None), eval=(None, plus_one_target))
        x, y = original_dataset[0]
        dataset = _make_taskaware_classification_dataset(
            original_dataset, transform_groups=transform_groups
        )

        # Default group is 'train': x becomes a Tensor, y is untouched.
        x2, y2, _ = dataset[0]
        self.assertIsInstance(x2, Tensor)
        self.assertIsInstance(y2, int)
        self.assertTrue(torch.equal(ToTensor()(x), x2))
        self.assertEqual(y, y2)

        # 'eval' group: x stays a PIL image, y is incremented.
        dataset_eval = dataset.eval()
        x3, y3, _ = dataset_eval[0]
        self.assertIsInstance(x3, PIL.Image.Image)
        self.assertIsInstance(y3, int)
        self.assertEqual(y + 1, y3)

        # Regression test for #565: re-wrapping an eval-mode dataset must
        # keep the 'eval' group active.
        dataset_inherit = _make_taskaware_classification_dataset(dataset_eval)
        x4, y4, _ = dataset_inherit[0]
        self.assertIsInstance(x4, PIL.Image.Image)
        self.assertIsInstance(y4, int)
        self.assertEqual(y + 1, y4)

        # Regression test for #566: a subset of an eval-mode subset must
        # also keep the 'eval' group active.
        dataset_sub_train = _taskaware_classification_subset(dataset)
        dataset_sub_eval = dataset_sub_train.eval()
        dataset_sub = _taskaware_classification_subset(dataset_sub_eval, indices=[0])

        x5, y5, _ = dataset_sub[0]
        self.assertIsInstance(x5, PIL.Image.Image)
        self.assertIsInstance(y5, int)
        self.assertEqual(y + 1, y5)
        # End regression tests

        # Concatenating two eval-mode datasets keeps the 'eval' group.
        concat_dataset = dataset_sub_eval.concat(dataset_sub)
        x6, y6, _ = concat_dataset[0]
        self.assertIsInstance(x6, PIL.Image.Image)
        self.assertIsInstance(y6, int)
        self.assertEqual(y + 1, y6)

        # DEPRECATED BEHAVIOR
        # concat_dataset_no_inherit_initial =
        #   AvalancheConcatClassificationDataset(
        #     [dataset_sub_eval, dataset]
        # )
        # x7, y7, _ = concat_dataset_no_inherit_initial[0]
        # self.assertIsInstance(x7, Tensor)
        # self.assertIsInstance(y7, int)
        # self.assertEqual(y, y7)
    def test_avalanche_inherit_groups_freeze_transforms(self):
        """Frozen transforms must survive a group replacement, while the
        unfrozen dataset's replacement removes them (test for #1353)."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)

        groups = dict(train=(RandomCrop(16), None), eval=(None, None))
        base = _make_taskaware_classification_dataset(
            mnist, transform_groups=groups
        )

        # Re-wrapping inherits the 'train' group -> crop to 16x16.
        wrapped = _make_taskaware_classification_dataset(base)
        sample_wrapped, *_ = wrapped[0]
        self.assertEqual(sample_wrapped.size, (16, 16))

        # Freezing keeps the crop active.
        frozen = wrapped.freeze_transforms()
        sample_frozen, *_ = frozen[0]
        self.assertEqual(sample_frozen.size, (16, 16))

        # Replacing the current group on the FROZEN dataset must NOT drop
        # the frozen crop.
        frozen_cleared = frozen.replace_current_transform_group(None)
        sample_frozen_cleared, *_ = frozen_cleared[0]
        self.assertEqual(sample_frozen_cleared.size, (16, 16))

        # Replacing the group on the unfrozen dataset removes the crop,
        # restoring the raw 28x28 MNIST image.
        cleared = wrapped.replace_current_transform_group(None)
        sample_cleared, *_ = cleared[0]
        self.assertEqual(sample_cleared.size, (28, 28))
    def test_freeze_transforms(self):
        """A frozen transform keeps applying after freezing and matches
        the result of applying the transform manually."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        raw_x, raw_y = mnist[0]

        wrapped = _make_taskaware_classification_dataset(
            mnist, transform=ToTensor()
        )
        frozen = wrapped.freeze_transforms()

        got_x, got_y, _ = frozen[0]
        self.assertIsInstance(got_x, Tensor)
        self.assertIsInstance(got_y, int)
        # The frozen pipeline must produce the same tensor as a manual
        # ToTensor() call on the raw PIL image.
        self.assertTrue(torch.equal(ToTensor()(raw_x), got_x))
        self.assertEqual(raw_y, got_y)
    def test_freeze_transforms_subset(self):
        """Freezing on a subset: frozen transforms persist through group
        replacement, and replacements on derived datasets are independent."""
        original_dataset = MNIST(root=default_dataset_location("mnist"), download=True)
        # NOTE: y is unused below; only the image sample is inspected.
        x, y = original_dataset[0]
        dataset: AvalancheDataset = _make_taskaware_classification_dataset(
            original_dataset, transform=ToTensor()
        )
        dataset_subset = dataset.subset((1, 2, 3))

        # The frozen ToTensor keeps applying...
        dataset_frozen = dataset_subset.freeze_transforms()
        x, *_ = dataset_frozen[0]
        self.assertIsInstance(x, Tensor)

        # ...even after clearing the current (unfrozen) transform group.
        dataset_frozen = dataset_frozen.replace_current_transform_group(None)

        x, *_ = dataset_frozen[0]
        self.assertIsInstance(x, Tensor)

        # Replacing the group returns a NEW dataset; the original frozen
        # dataset must be unaffected by the derivative's ToPILImage.
        dataset_frozen_derivative = dataset_frozen.replace_current_transform_group(
            ToPILImage()
        )

        x, *_ = dataset_frozen[0]
        x2, *_ = dataset_frozen_derivative[0]
        self.assertIsInstance(x, Tensor)
        self.assertIsInstance(x2, Image)

        # Applying the same replacement to dataset_frozen itself now makes
        # both datasets yield PIL images.
        dataset_frozen = dataset_frozen.replace_current_transform_group(ToPILImage())

        x, *_ = dataset_frozen[0]
        x2, *_ = dataset_frozen_derivative[0]
        self.assertIsInstance(x, Image)
        self.assertIsInstance(x2, Image)
    def test_freeze_transforms_chain(self):
        """Chained freeze/replace operations: each freeze snapshots the
        current pipeline and later replacements do not leak into it."""
        original_dataset = MNIST(
            root=default_dataset_location("mnist"),
            download=True,
            transform=ToTensor(),
        )
        x, *_ = original_dataset[0]
        self.assertIsInstance(x, Tensor)

        dataset_transform = _make_taskaware_classification_dataset(
            original_dataset, transform=ToPILImage()
        )  # TRANSFORMS: ToTensor -> ToPILImage
        self.assertIsInstance(dataset_transform[0][0], Image)

        # Freezing snapshots the ToTensor -> ToPILImage pipeline.
        dataset_frozen = dataset_transform.freeze_transforms()
        self.assertIsInstance(dataset_frozen[0][0], Image)
        self.assertIsInstance(dataset_transform[0][0], Image)

        # Clearing the unfrozen dataset's group drops ToPILImage there,
        # but the frozen snapshot still yields PIL images.
        dataset_transform = dataset_transform.replace_current_transform_group(None)
        self.assertIsInstance(dataset_transform[0][0], Tensor)
        self.assertIsInstance(dataset_frozen[0][0], Image)

        # Appending ToTensor after the frozen pipeline: PIL -> Tensor again.
        dataset_frozen = dataset_frozen.replace_current_transform_group(ToTensor())
        self.assertIsInstance(dataset_transform[0][0], Tensor)
        self.assertIsInstance(dataset_frozen[0][0], Tensor)

        # A second freeze snapshots the extended pipeline.
        dataset_frozen2 = dataset_frozen.freeze_transforms()
        x2, *_ = dataset_frozen2[0]
        self.assertIsInstance(x2, Tensor)

        # Clearing dataset_frozen's group must not affect the second
        # snapshot (still Tensor) while dataset_frozen itself falls back
        # to the first frozen pipeline (PIL image).
        dataset_frozen = dataset_frozen.replace_current_transform_group(None)
        x2, *_ = dataset_frozen2[0]
        self.assertIsInstance(x2, Tensor)
        x2, *_ = dataset_frozen[0]
        self.assertIsInstance(x2, Image)
    def test_replace_transforms(self):
        """replace_current_transform_group returns a new dataset and never
        mutates the dataset it is called on."""
        original_dataset = MNIST(root=default_dataset_location("mnist"), download=True)
        x, y = original_dataset[0]
        dataset = _make_taskaware_classification_dataset(
            original_dataset, transform=ToTensor()
        )
        x2, *_ = dataset[0]
        # Replacing with None clears the transform in the returned copy.
        dataset_reset = dataset.replace_current_transform_group(None)
        x3, *_ = dataset_reset[0]

        self.assertIsInstance(x, Image)
        self.assertIsInstance(x2, Tensor)
        self.assertIsInstance(x3, Image)

        dataset_reset = dataset_reset.replace_current_transform_group(ToTensor())
        x4, *_ = dataset_reset[0]
        self.assertIsInstance(x4, Tensor)

        # Return value deliberately discarded: the assertion below shows
        # the call does NOT modify dataset_reset in place (x5 is still a
        # Tensor). NOTE(review): if an actual reset was intended here, the
        # result would need to be reassigned — confirm intent.
        dataset_reset.replace_current_transform_group(None)

        x5, *_ = dataset_reset[0]
        self.assertIsInstance(x5, Tensor)

        # A replacement tuple (x_transform, y_transform) may also change
        # only the target transform.
        dataset_other = _make_taskaware_classification_dataset(dataset_reset)
        dataset_other = dataset_other.replace_current_transform_group(
            (None, lambda lll: lll + 1)
        )

        _, y6, _ = dataset_other[0]
        self.assertEqual(y + 1, y6)
    def test_transforms_replace_freeze_mix(self):
        """Interaction of replace and freeze: replacing never mutates the
        source dataset, and frozen transforms survive a later replacement."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        raw_sample, _ = mnist[0]

        wrapped = _make_taskaware_classification_dataset(
            mnist, transform=ToTensor()
        )
        tensor_sample, *_ = wrapped[0]

        # Clearing the group affects only the returned copy.
        cleared = wrapped.replace_current_transform_group((None, None))
        cleared_sample, *_ = cleared[0]

        self.assertIsInstance(raw_sample, Image)
        self.assertIsInstance(tensor_sample, Tensor)
        self.assertIsInstance(cleared_sample, Image)

        # Freezing snapshots ToTensor...
        frozen = wrapped.freeze_transforms()
        frozen_sample, *_ = frozen[0]
        self.assertIsInstance(frozen_sample, Tensor)

        # ...so clearing the (now empty) unfrozen group keeps the Tensor
        # output.
        frozen_cleared = frozen.replace_current_transform_group(
            (None, None)
        )
        frozen_cleared_sample, *_ = frozen_cleared[0]
        self.assertIsInstance(frozen_cleared_sample, Tensor)
    def test_transforms_groups_base_usage(self):
        """Switching between 'train' and 'eval' groups returns independent
        dataset views, each applying its own group's transforms."""
        original_dataset = MNIST(root=default_dataset_location("mnist"), download=True)
        dataset = _make_taskaware_classification_dataset(
            original_dataset,
            transform_groups=dict(
                train=(ToTensor(), None),
                eval=(None, Lambda(lambda t: float(t))),
            ),
        )

        # Default group is 'train': Tensor image, int target.
        x, y, _ = dataset[0]
        self.assertIsInstance(x, Tensor)
        self.assertIsInstance(y, int)

        dataset_test = dataset.eval()

        # eval(): PIL image, float target; the original dataset keeps its
        # 'train' behavior.
        x2, y2, _ = dataset_test[0]
        x3, y3, _ = dataset[0]
        self.assertIsInstance(x2, Image)
        self.assertIsInstance(y2, float)
        self.assertIsInstance(x3, Tensor)
        self.assertIsInstance(y3, int)

        # train() keeps its transforms even after the source dataset's
        # current group is cleared.
        dataset_train = dataset.train()
        dataset = dataset.replace_current_transform_group(None)

        x4, y4, _ = dataset_train[0]
        x5, y5, _ = dataset[0]
        self.assertIsInstance(x4, Tensor)
        self.assertIsInstance(y4, int)
        self.assertIsInstance(x5, Image)
        self.assertIsInstance(y5, int)
    def test_transforms_groups_constructor_error(self):
        """Malformed transform_groups definitions must raise at
        construction time."""
        # NOTE(review): load_image_data is presumably a helper defined
        # earlier in this file — confirm; other tests here use MNIST.
        original_dataset = load_image_data()

        with self.assertRaises(Exception):
            # The 'eval' group is a list, not the required tuple
            dataset = _make_taskaware_classification_dataset(
                original_dataset,
                transform_groups=dict(
                    train=(ToTensor(), None),
                    eval=[None, Lambda(lambda t: float(t))],
                ),
            )

        with self.assertRaises(Exception):
            # The 'train' group is None instead of a tuple
            dataset = _make_taskaware_classification_dataset(
                original_dataset,
                transform_groups=dict(
                    train=None, eval=(None, Lambda(lambda t: float(t)))
                ),
            )

        with self.assertRaises(Exception):
            # transform_groups is not a dictionary
            dataset = _make_taskaware_classification_dataset(
                original_dataset, transform_groups="Hello world!"
            )
    def test_transforms_groups_alternative_default_group(self):
        """initial_transform_group='eval' makes the eval group active from
        the start, so samples stay PIL images before and after eval()."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        wrapped = _make_taskaware_classification_dataset(
            mnist,
            transform_groups=dict(train=(ToTensor(), None), eval=(None, None)),
            initial_transform_group="eval",
        )

        # 'eval' is already active: no ToTensor applied.
        first_sample, *_ = wrapped[0]
        self.assertIsInstance(first_sample, Image)

        # Switching to eval() is a no-op for the output type, for both the
        # new view and the original dataset.
        eval_view = wrapped.eval()
        eval_sample, *_ = eval_view[0]
        original_sample, *_ = wrapped[0]
        self.assertIsInstance(eval_sample, Image)
        self.assertIsInstance(original_sample, Image)
    def test_transforms_groups_partial_constructor(self):
        """When only the 'train' group is defined, eval() falls back to the
        train transforms instead of failing."""
        mnist = MNIST(root=default_dataset_location("mnist"), download=True)
        wrapped = _make_taskaware_classification_dataset(
            mnist, transform_groups=dict(train=(ToTensor(), None))
        )

        # 'train' group active: ToTensor applies.
        train_sample, *_ = wrapped[0]
        self.assertIsInstance(train_sample, Tensor)

        # No explicit 'eval' group -> still ToTensor output.
        wrapped = wrapped.eval()
        eval_sample, *_ = wrapped[0]
        self.assertIsInstance(eval_sample, Tensor)
    def test_transforms_groups_multiple_groups(self):
        """Custom transform groups beyond train/eval can be selected with
        with_transforms(name)."""
        original_dataset = MNIST(root=default_dataset_location("mnist"), download=True)
        dataset = _make_taskaware_classification_dataset(
            original_dataset,
            transform_groups=dict(
                train=(ToTensor(), None),
                eval=(None, None),
                # Custom group producing numpy arrays.
                other=(
                    Compose([ToTensor(), Lambda(lambda tensor: tensor.numpy())]),
                    None,
                ),
            ),
        )

        # Default 'train' group -> Tensor.
        x, *_ = dataset[0]
        self.assertIsInstance(x, Tensor)

        # 'eval' group -> untouched PIL image.
        dataset = dataset.eval()
        x2, *_ = dataset[0]
        self.assertIsInstance(x2, Image)

        # Custom 'other' group -> numpy ndarray.
        dataset = dataset.with_transforms("other")
        x3, *_ = dataset[0]
        self.assertIsInstance(x3, np.ndarray)
    def test_transformation_concat_dataset(self):
        """Concatenating two datasets yields the sum of their lengths."""
        first = MNIST(root=default_dataset_location("mnist"), download=True)
        second = MNIST(root=default_dataset_location("mnist"), download=True)

        combined = concat_datasets([first, second])
        expected_len = len(first) + len(second)
        self.assertEqual(expected_len, len(combined))
    def test_transformation_concat_dataset_groups(self):
1✔
1867
        original_dataset = _make_taskaware_classification_dataset(
×
1868
            MNIST(root=default_dataset_location("mnist"), download=True),
1869
            transform_groups=dict(eval=(None, None), train=(ToTensor(), None)),
1870
        )
1871
        original_dataset2 = _make_taskaware_classification_dataset(
×
1872
            MNIST(root=default_dataset_location("mnist"), download=True),
1873
            transform_groups=dict(train=(None, None), eval=(ToTensor(), None)),
1874
        )
1875

1876
        dataset = original_dataset.concat(original_dataset2)
×
1877

1878
        self.assertEqual(len(original_dataset) + len(original_dataset2), len(dataset))
×
1879

1880
        x, *_ = dataset[0]
×
1881
        x2, *_ = dataset[len(original_dataset)]
×
1882
        self.assertIsInstance(x, Tensor)
×
1883
        self.assertIsInstance(x2, Image)
×
1884

1885
        dataset = dataset.eval()
×
1886

1887
        x3, *_ = dataset[0]
×
1888
        x4, *_ = dataset[len(original_dataset)]
×
1889
        self.assertIsInstance(x3, Image)
×
1890
        self.assertIsInstance(x4, Tensor)
×
1891

1892

1893
# Allow running this test module directly (python test_file.py).
if __name__ == "__main__":
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc