• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 4993189103

pending completion
4993189103

Pull #1370

github

Unknown Committer
Unknown Commit Message
Pull Request #1370: Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered. (31.39%)

80 existing lines in 5 files now uncovered.

15585 of 21651 relevant lines covered (71.98%)

2.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.71
/tests/test_high_level_generators.py
1
import os
4✔
2
import tempfile
4✔
3
import unittest
4✔
4
from os.path import expanduser
4✔
5

6
import torch
4✔
7
from numpy.testing import assert_almost_equal
4✔
8
from torchvision.datasets import MNIST
4✔
9
from torchvision.datasets.utils import download_url, extract_archive
4✔
10
from torchvision.transforms import ToTensor
4✔
11
from tests.unit_tests_utils import DummyImageDataset
4✔
12

13

14
from avalanche.benchmarks import (
4✔
15
    dataset_benchmark,
16
    filelist_benchmark,
17
    tensors_benchmark,
18
    paths_benchmark,
19
    data_incremental_benchmark,
20
    benchmark_with_validation_stream,
21
)
22
from avalanche.benchmarks.datasets import default_dataset_location
4✔
23
from avalanche.benchmarks.generators.benchmark_generators import (
4✔
24
    class_balanced_split_strategy,
25
)
26
from avalanche.benchmarks.scenarios.generic_benchmark_creation import (
4✔
27
    create_lazy_generic_benchmark,
28
    LazyStreamDefinition,
29
)
30
from avalanche.benchmarks.utils import (
4✔
31
    make_classification_dataset,
32
    make_tensor_classification_dataset,
33
)
34
from tests.test_avalanche_classification_dataset import get_mbatch
4✔
35
from tests.unit_tests_utils import common_setups, get_fast_benchmark
4✔
36

37

38
class HighLevelGeneratorTests(unittest.TestCase):
    """Tests for the high-level benchmark generators of Avalanche.

    Covers ``dataset_benchmark``, ``filelist_benchmark``, ``paths_benchmark``,
    ``tensors_benchmark``, ``data_incremental_benchmark`` and
    ``benchmark_with_validation_stream`` (on both eager and lazy benchmarks).
    """

    # Shape of each synthetic pattern used by the tensor-based tests.
    PATTERN_SHAPE = (3, 32, 32)

    def setUp(self):
        common_setups()

    # ------------------------------------------------------------------
    # Private fixture / assertion helpers (not collected as tests).
    # ------------------------------------------------------------------

    @classmethod
    def _make_tensors(cls):
        """Create the synthetic tensors shared by the tensor-based tests.

        :return: ``(exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y)``:
            100 all-zero patterns labelled 0 (1st training experience),
            80 all-zero patterns labelled 1 (2nd training experience) and
            50 all-zero patterns labelled 0 (test set).
        """
        exp_1_x = torch.zeros(100, *cls.PATTERN_SHAPE)
        exp_1_y = torch.zeros(100, dtype=torch.long)
        exp_2_x = torch.zeros(80, *cls.PATTERN_SHAPE)
        exp_2_y = torch.ones(80, dtype=torch.long)
        test_x = torch.zeros(50, *cls.PATTERN_SHAPE)
        test_y = torch.zeros(50, dtype=torch.long)
        return exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y

    @staticmethod
    def _make_tensors_benchmark(
        exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y
    ):
        """Build an eager 2-experience benchmark with a single test set."""
        return tensors_benchmark(
            train_tensors=[(exp_1_x, exp_1_y), (exp_2_x, exp_2_y)],
            test_tensors=[(test_x, test_y)],
            task_labels=[0, 0],  # Task label of each train exp
            complete_test_set_only=True,
        )

    @staticmethod
    def _make_lazy_benchmark(
        exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y
    ):
        """Build the lazily-generated counterpart of the tensors benchmark.

        The datasets are created up-front; only their delivery to the
        benchmark goes through generators.
        """
        train_datasets = [
            make_tensor_classification_dataset(exp_1_x, exp_1_y),
            make_tensor_classification_dataset(exp_2_x, exp_2_y),
        ]
        test_datasets = [make_tensor_classification_dataset(test_x, test_y)]

        def train_gen():
            # Lazy generator of the training stream
            yield from train_datasets

        def test_gen():
            # Lazy generator of the test stream
            yield from test_datasets

        return create_lazy_generic_benchmark(
            train_generator=LazyStreamDefinition(train_gen(), 2, [0, 0]),
            test_generator=LazyStreamDefinition(test_gen(), 1, [0]),
            complete_test_set_only=True,
        )

    @staticmethod
    def _download_cats_and_dogs():
        """Download and extract the "cats and dogs" archive.

        :return: path of the extracted ``train`` directory, which the tests
            expect to contain a ``cats`` and a ``dogs`` sub-directory.
        """
        data_root = os.path.join(expanduser("~"), ".avalanche", "data")
        archive_name = "cats_and_dogs_filtered.zip"
        download_url(
            "https://storage.googleapis.com/mledu-datasets/" + archive_name,
            data_root,
            archive_name,
        )
        extract_archive(
            os.path.join(data_root, archive_name), to_path=data_root
        )
        return os.path.join(data_root, "cats_and_dogs_filtered", "train")

    def _assert_dataset_content(self, dataset, expected_x, expected_y):
        """Load ``dataset`` as a single mini-batch and compare its patterns
        and targets with ``expected_x`` and ``expected_y``."""
        mb = get_mbatch(dataset, len(dataset))
        self.assertTrue(torch.equal(expected_x, mb[0]))
        self.assertTrue(torch.equal(expected_y, mb[1]))

    def _assert_data_incremental(
        self, instance, exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y
    ):
        """Check a data-incremental benchmark built from the tensors of
        ``_make_tensors`` with experience size 12, ``shuffle=False`` and
        ``drop_last=False``."""
        self.assertEqual(16, len(instance.train_stream))
        self.assertEqual(1, len(instance.test_stream))
        self.assertTrue(instance.complete_test_set_only)

        # 100 patterns -> 8 mini-exps of 12 + a last one of 4 (ids 0..8);
        # 80 patterns -> 6 mini-exps of 12 + a last one of 8 (ids 9..15).
        last_mini_exp_sizes = {8: 4, 15: 8}

        tensor_idx = 0
        ref_tensor_x, ref_tensor_y = exp_1_x, exp_1_y
        for exp in instance.train_stream:
            self.assertEqual(
                last_mini_exp_sizes.get(exp.current_experience, 12),
                len(exp.dataset),
            )

            if tensor_idx >= len(exp_1_x):
                # 1st original experience exhausted: continue on the 2nd one
                ref_tensor_x, ref_tensor_y = exp_2_x, exp_2_y
                tensor_idx = 0

            for x, y, *_ in exp.dataset:
                self.assertTrue(torch.equal(ref_tensor_x[tensor_idx], x))
                self.assertTrue(
                    torch.equal(ref_tensor_y[tensor_idx], torch.tensor(y))
                )
                tensor_idx += 1

        test_exp = instance.test_stream[0]
        self.assertEqual(len(test_x), len(test_exp.dataset))
        for idx, (x, y, *_) in enumerate(test_exp.dataset):
            self.assertTrue(torch.equal(test_x[idx], x))
            self.assertTrue(torch.equal(test_y[idx], torch.tensor(y)))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_dataset_benchmark(self):
        """dataset_benchmark creates one experience per given dataset."""
        train_mnist = MNIST(
            root=default_dataset_location("mnist"), train=True, download=True
        )
        test_mnist = MNIST(
            root=default_dataset_location("mnist"), train=False, download=True
        )

        train_cifar10 = DummyImageDataset(n_classes=10)
        test_cifar10 = DummyImageDataset(n_classes=10)

        generic_benchmark = dataset_benchmark(
            [train_mnist, train_cifar10], [test_mnist, test_cifar10]
        )

        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(2, len(generic_benchmark.test_stream))

    def test_dataset_benchmark_avalanche_dataset(self):
        """Task labels of Avalanche datasets reach the experiences."""
        train_mnist = make_classification_dataset(
            MNIST(
                root=default_dataset_location("mnist"),
                train=True,
                download=True,
            ),
            task_labels=0,
        )
        test_mnist = make_classification_dataset(
            MNIST(
                root=default_dataset_location("mnist"),
                train=False,
                download=True,
            ),
            task_labels=0,
        )
        train_cifar10 = make_classification_dataset(
            DummyImageDataset(n_classes=10), task_labels=1
        )
        test_cifar10 = make_classification_dataset(
            DummyImageDataset(n_classes=10), task_labels=1
        )

        generic_benchmark = dataset_benchmark(
            [train_mnist, train_cifar10], [test_mnist, test_cifar10]
        )

        self.assertEqual(0, generic_benchmark.train_stream[0].task_label)
        self.assertEqual(1, generic_benchmark.train_stream[1].task_label)
        self.assertEqual(0, generic_benchmark.test_stream[0].task_label)
        self.assertEqual(1, generic_benchmark.test_stream[1].task_label)

    def test_filelist_benchmark(self):
        """filelist_benchmark builds streams from "<path> <label>" lists."""
        dirpath = self._download_cats_and_dogs()

        with tempfile.TemporaryDirectory() as tmpdirname:
            list_paths = []
            for filelist, rel_dir, label in zip(
                ["train_filelist_00.txt", "train_filelist_01.txt"],
                ["cats", "dogs"],
                [0, 1],
            ):
                # One "<path relative to dirpath> <label>" line per file
                filenames_list = os.listdir(os.path.join(dirpath, rel_dir))
                filelist_path = os.path.join(tmpdirname, filelist)
                list_paths.append(filelist_path)
                with open(filelist_path, "w") as wf:
                    wf.writelines(
                        "{} {}\n".format(os.path.join(rel_dir, name), label)
                        for name in filenames_list
                    )

            # Created inside the `with` so the file lists still exist when
            # filelist_benchmark reads them.
            generic_benchmark = filelist_benchmark(
                dirpath,
                list_paths,
                [list_paths[0]],
                task_labels=[0, 0],
                complete_test_set_only=True,
                train_transform=ToTensor(),
                eval_transform=ToTensor(),
            )

        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(1, len(generic_benchmark.test_stream))

    def test_paths_benchmark(self):
        """paths_benchmark builds streams from (path, label) lists."""
        dirpath = self._download_cats_and_dogs()

        train_experiences = []
        for rel_dir, label in zip(["cats", "dogs"], [0, 1]):
            class_dir = os.path.join(dirpath, rel_dir)
            train_experiences.append(
                [
                    (os.path.join(class_dir, name), label)
                    for name in os.listdir(class_dir)
                ]
            )

        generic_benchmark = paths_benchmark(
            train_experiences,
            [train_experiences[0]],  # Single test set
            task_labels=[0, 0],
            complete_test_set_only=True,
            train_transform=ToTensor(),
            eval_transform=ToTensor(),
        )

        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(1, len(generic_benchmark.test_stream))

    def test_tensors_benchmark(self):
        """tensors_benchmark creates one experience per tensor pair."""
        generic_benchmark = self._make_tensors_benchmark(*self._make_tensors())

        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(1, len(generic_benchmark.test_stream))

    def test_data_incremental_benchmark(self):
        """data_incremental_benchmark splits eager experiences in order."""
        tensors = self._make_tensors()
        initial_benchmark_instance = self._make_tensors_benchmark(*tensors)

        data_incremental_instance = data_incremental_benchmark(
            initial_benchmark_instance, 12, shuffle=False, drop_last=False
        )

        self._assert_data_incremental(data_incremental_instance, *tensors)

    def test_data_incremental_benchmark_from_lazy_benchmark(self):
        """data_incremental_benchmark splits lazy experiences in order."""
        tensors = self._make_tensors()
        initial_benchmark_instance = self._make_lazy_benchmark(*tensors)

        data_incremental_instance = data_incremental_benchmark(
            initial_benchmark_instance, 12, shuffle=False, drop_last=False
        )

        self._assert_data_incremental(data_incremental_instance, *tensors)

    def test_benchmark_with_validation_stream_fixed_size(self):
        """A fixed-size validation split holds out the last 20 patterns."""
        tensors = self._make_tensors()
        exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y = tensors

        valid_benchmark = benchmark_with_validation_stream(
            self._make_tensors_benchmark(*tensors), 20, shuffle=False
        )

        self.assertEqual(2, len(valid_benchmark.train_stream))
        self.assertEqual(2, len(valid_benchmark.valid_stream))
        self.assertEqual(1, len(valid_benchmark.test_stream))
        self.assertTrue(valid_benchmark.complete_test_set_only)

        train_stream = valid_benchmark.train_stream
        valid_stream = valid_benchmark.valid_stream
        self.assertEqual(80, len(train_stream[0].dataset))
        self.assertEqual(60, len(train_stream[1].dataset))
        self.assertEqual(20, len(valid_stream[0].dataset))
        self.assertEqual(20, len(valid_stream[1].dataset))

        self._assert_dataset_content(
            train_stream[0].dataset, exp_1_x[:80], exp_1_y[:80]
        )
        self._assert_dataset_content(
            train_stream[1].dataset, exp_2_x[:60], exp_2_y[:60]
        )
        self._assert_dataset_content(
            valid_stream[0].dataset, exp_1_x[80:], exp_1_y[80:]
        )
        self._assert_dataset_content(
            valid_stream[1].dataset, exp_2_x[60:], exp_2_y[60:]
        )
        self._assert_dataset_content(
            valid_benchmark.test_stream[0].dataset, test_x, test_y
        )

    def test_benchmark_with_validation_stream_rel_size(self):
        """A relative-size validation split holds out 20% of each exp."""
        tensors = self._make_tensors()
        exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y = tensors

        valid_benchmark = benchmark_with_validation_stream(
            self._make_tensors_benchmark(*tensors), 0.2, shuffle=False
        )
        true_rel_1_valid = int(100 * 0.2)
        true_rel_1_train = 100 - true_rel_1_valid
        true_rel_2_valid = int(80 * 0.2)
        true_rel_2_train = 80 - true_rel_2_valid

        self.assertEqual(2, len(valid_benchmark.train_stream))
        self.assertEqual(2, len(valid_benchmark.valid_stream))
        self.assertEqual(1, len(valid_benchmark.test_stream))
        self.assertTrue(valid_benchmark.complete_test_set_only)

        train_stream = valid_benchmark.train_stream
        valid_stream = valid_benchmark.valid_stream
        self.assertEqual(true_rel_1_train, len(train_stream[0].dataset))
        self.assertEqual(true_rel_2_train, len(train_stream[1].dataset))
        self.assertEqual(true_rel_1_valid, len(valid_stream[0].dataset))
        self.assertEqual(true_rel_2_valid, len(valid_stream[1].dataset))

        # Every split is checked on both its patterns and its targets.
        self._assert_dataset_content(
            train_stream[0].dataset,
            exp_1_x[:true_rel_1_train],
            exp_1_y[:true_rel_1_train],
        )
        self._assert_dataset_content(
            train_stream[1].dataset,
            exp_2_x[:true_rel_2_train],
            exp_2_y[:true_rel_2_train],
        )
        self._assert_dataset_content(
            valid_stream[0].dataset,
            exp_1_x[true_rel_1_train:],
            exp_1_y[true_rel_1_train:],
        )
        self._assert_dataset_content(
            valid_stream[1].dataset,
            exp_2_x[true_rel_2_train:],
            exp_2_y[true_rel_2_train:],
        )
        self._assert_dataset_content(
            valid_benchmark.test_stream[0].dataset, test_x, test_y
        )

        # Regression test for #1371
        self.assertEqual([0], train_stream[0].classes_in_this_experience)

    def test_lazy_benchmark_with_validation_stream_fixed_size(self):
        """Validation splitting of a lazy benchmark honors lazy_splitting."""
        for lazy_option in [None, True, False]:
            with self.subTest(lazy_option=lazy_option):
                tensors = self._make_tensors()
                exp_1_x, exp_1_y, exp_2_x, exp_2_y, test_x, test_y = tensors

                valid_benchmark = benchmark_with_validation_stream(
                    self._make_lazy_benchmark(*tensors),
                    20,
                    shuffle=False,
                    lazy_splitting=lazy_option,
                )

                # Splitting is expected to be lazy unless explicitly disabled
                expect_laziness = lazy_option is not False
                self.assertEqual(
                    expect_laziness,
                    valid_benchmark.stream_definitions["train"].is_lazy,
                )

                self.assertEqual(2, len(valid_benchmark.train_stream))
                self.assertEqual(2, len(valid_benchmark.valid_stream))
                self.assertEqual(1, len(valid_benchmark.test_stream))
                self.assertTrue(valid_benchmark.complete_test_set_only)

                # (stream name, experience id, expected dataset length)
                checked_exps = [
                    ("train", 0, 80),
                    ("train", 1, 60),
                    ("valid", 0, 20),
                    ("valid", 1, 20),
                ]

                # With lazy splitting an experience must not be loaded
                # before it is first accessed.
                for stream_name, exp_id, expected_len in checked_exps:
                    exps_data = valid_benchmark.stream_definitions[
                        stream_name
                    ].exps_data
                    maybe_exp = exps_data.get_experience_if_loaded(exp_id)
                    self.assertEqual(expect_laziness, maybe_exp is None)

                    stream = getattr(valid_benchmark, stream_name + "_stream")
                    self.assertEqual(expected_len, len(stream[exp_id].dataset))

                # After being accessed, every experience must be loaded.
                for stream_name, exp_id, _ in checked_exps:
                    exps_data = valid_benchmark.stream_definitions[
                        stream_name
                    ].exps_data
                    self.assertIsNotNone(
                        exps_data.get_experience_if_loaded(exp_id)
                    )

                self._assert_dataset_content(
                    valid_benchmark.train_stream[0].dataset,
                    exp_1_x[:80],
                    exp_1_y[:80],
                )
                self._assert_dataset_content(
                    valid_benchmark.train_stream[1].dataset,
                    exp_2_x[:60],
                    exp_2_y[:60],
                )
                self._assert_dataset_content(
                    valid_benchmark.valid_stream[0].dataset,
                    exp_1_x[80:],
                    exp_1_y[80:],
                )
                self._assert_dataset_content(
                    valid_benchmark.valid_stream[1].dataset,
                    exp_2_x[60:],
                    exp_2_y[60:],
                )
                self._assert_dataset_content(
                    valid_benchmark.test_stream[0].dataset, test_x, test_y
                )
690

691

692
class DataSplitStrategiesTests(unittest.TestCase):
    """Tests for ``class_balanced_split_strategy``."""

    def test_dataset_benchmark(self):
        """class_balanced_split_strategy splits every class by the given
        validation ratio."""
        benchmark = get_fast_benchmark(n_samples_per_class=1000)
        exp = benchmark.train_stream[0]
        classes = exp.classes_in_this_experience
        num_classes = len(classes)

        # 50/50 split: per class, train and valid sizes may differ by at
        # most one pattern, hence by at most num_classes overall.
        train_d, valid_d = class_balanced_split_strategy(0.5, exp)
        self.assertLessEqual(abs(len(train_d) - len(valid_d)), num_classes)
        train_targets = torch.as_tensor(train_d.targets)
        valid_targets = torch.as_tensor(valid_d.targets)
        for cid in classes:
            train_cnt = int((train_targets == cid).sum())
            valid_cnt = int((valid_targets == cid).sum())
            self.assertLessEqual(abs(train_cnt - valid_cnt), 1)

        # Arbitrary ratio: the validation fraction must match it both
        # globally and per class, up to 2 decimals.
        ratio = 0.123
        len_data = len(exp.dataset)
        _, valid_d = class_balanced_split_strategy(ratio, exp)
        assert_almost_equal(len(valid_d) / len_data, ratio, decimal=2)
        data_targets = torch.as_tensor(exp.dataset.targets)
        valid_targets = torch.as_tensor(valid_d.targets)
        for cid in classes:
            data_cnt = int((data_targets == cid).sum())
            valid_cnt = int((valid_targets == cid).sum())
            assert_almost_equal(valid_cnt / data_cnt, ratio, decimal=2)
713

714

715
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes defined
    # in this module (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc