ContinualAI / avalanche, build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200 (github, via web-flow)

417 of 538 new or added lines in 30 files covered. (77.51%)

43 existing lines in 5 files now uncovered.

16586 of 22630 relevant lines covered (73.29%)

2.93 hits per line

Source File
/avalanche/benchmarks/scenarios/generic_benchmark_creation.py (84.94% of lines covered)
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 16-04-2021                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

""" This module contains mid-level benchmark generators.
Consider using the higher-level ones found in benchmark_generators. If none of
them fit your needs, then the helper functions listed here may help.
"""

import itertools
from pathlib import Path
from typing import (
    Callable,
    Generator,
    List,
    Mapping,
    Sequence,
    TypeVar,
    Union,
    Any,
    Tuple,
    Dict,
    Optional,
    Iterable,
    NamedTuple,
)
from typing_extensions import (
    Protocol,
    Literal,
)
import warnings
from avalanche.benchmarks.scenarios.classification_scenario import (
    ClassificationExperience,
    ClassificationScenario,
    ClassificationStream,
)
from avalanche.benchmarks.scenarios.dataset_scenario import (
    DatasetScenario,
    DatasetStream,
    FactoryBasedStream,
    TStreamsUserDict,
)
from avalanche.benchmarks.scenarios.generic_scenario import DatasetExperience

from avalanche.benchmarks.utils import (
    FilelistDataset,
    PathsDataset,
    common_paths_root,
)
from torch.utils.data.dataset import Subset, ConcatDataset
from avalanche.benchmarks.utils.classification_dataset import (
    ClassificationDataset,
    make_classification_dataset,
)
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.transform_groups import (
    TransformGroupDef,
    XTransform,
    YTransform,
)
from avalanche.benchmarks.utils.utils import (
    _is_int_iterable,
    make_generic_dataset,
    make_generic_tensor_dataset,
)
from avalanche.benchmarks.utils.dataset_definitions import (
    IDatasetWithTargets,
    ITensorDataset,
)


TDatasetScenario = TypeVar(
    'TDatasetScenario',
    bound='DatasetScenario')

TTargetType = TypeVar(
    'TTargetType',
    contravariant=True)
TSupportedDataset = TypeVar(
    'TSupportedDataset',
    contravariant=True)
TAvalancheDataset = TypeVar(
    'TAvalancheDataset',
    bound='AvalancheDataset',
    covariant=True)


GenericSupportedDataset = Union[
    IDatasetWithTargets,
    ITensorDataset,
    Subset,
    ConcatDataset,
    AvalancheDataset
]


class DatasetFactory(
        Protocol[
            TSupportedDataset,
            TTargetType,
            TAvalancheDataset]):
    def __call__(
        self,
        dataset: TSupportedDataset,
        *,
        transform: Optional[XTransform] = None,
        target_transform: Optional[YTransform] = None,
        transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
        initial_transform_group: Optional[str] = None,
        task_labels: Optional[Union[int, Sequence[int]]] = None,
        targets: Optional[Sequence[TTargetType]] = None,
        collate_fn: Optional[Callable[[List], Any]] = None
    ) -> TAvalancheDataset:
        ...


class TensorDatasetFactory(
        Protocol[
            TAvalancheDataset]):
    def __call__(
        self,
        dataset_tensors: Sequence,
        *,
        task_labels: Optional[Union[int, Sequence[int]]] = None,
    ) -> TAvalancheDataset:
        ...


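# The two protocols above only describe call signatures. As an illustration,
# the following sketch (not part of the original module) shows a callable that
# conforms to `DatasetFactory` by forwarding every keyword argument to
# :func:`make_classification_dataset`, the same callable that
# `_manage_legacy_classification_usage` (defined later in this module) selects
# when classification datasets are detected. The fallback to task label 0 is
# an assumption made only for this example.
def _example_classification_dataset_factory(dataset, **kwargs):
    """A hedged sketch of a custom `DatasetFactory`-compatible callable."""
    if kwargs.get('task_labels') is None:
        # Assumption: default to task label 0 when the caller provides none.
        kwargs['task_labels'] = 0
    return make_classification_dataset(dataset, **kwargs)

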
def _make_plain_experience(
    stream: DatasetStream[DatasetExperience[TAvalancheDataset]],
    experience_idx: int
) -> DatasetExperience[TAvalancheDataset]:
    dataset = stream.benchmark.stream_definitions[
        stream.name
    ].exps_data[experience_idx]

    return DatasetExperience(
        current_experience=experience_idx,
        origin_stream=stream,
        benchmark=stream.benchmark,
        dataset=dataset
    )


def _make_generic_scenario(
        stream_definitions: TStreamsUserDict,
        complete_test_set_only: bool):
    return DatasetScenario(
        stream_definitions=stream_definitions,
        complete_test_set_only=complete_test_set_only,
        stream_factory=FactoryBasedStream,
        experience_factory=_make_plain_experience
    )


def _make_classification_scenario(
    stream_definitions: TStreamsUserDict,
    complete_test_set_only: bool
) -> ClassificationScenario[
        ClassificationStream[
            ClassificationExperience[
                ClassificationDataset]],
        ClassificationExperience[
            ClassificationDataset],
        ClassificationDataset]:
    return ClassificationScenario(
        stream_definitions=stream_definitions,
        complete_test_set_only=complete_test_set_only
    )


def _detect_legacy_classification_usage(
    all_datasets: Iterable[Any]
) -> bool:
    """
    Used by :func:`create_multi_dataset_generic_benchmark` to check
    if the user is trying to create a classification benchmark.

    While using :func:`create_multi_dataset_generic_benchmark` to create a
    classification benchmark is acceptable, it would be better to use
    :func:`create_multi_dataset_classification_benchmark`, which returns
    a :class:`ClassificationScenario`.

    Fields defined in :class:`ClassificationScenario` are not to be found
    in the generic :class:`DatasetScenario` instance returned by
    :func:`create_multi_dataset_generic_benchmark` and may be needed
    by some continual learning strategies.

    This function works by checking whether the targets of all input
    datasets are ints (including NumPy/PyTorch int types).
    """

    for dataset in all_datasets:
        try:
            as_classification_dataset = make_classification_dataset(
                dataset
            )
            if not _is_int_iterable(as_classification_dataset.targets):
                return False
        except Exception:
            return False

    return True


def _manage_legacy_classification_usage(
    train_datasets: Sequence[GenericSupportedDataset],
    test_datasets: Sequence[GenericSupportedDataset],
    other_streams_datasets: Optional[
        Mapping[str, Sequence[GenericSupportedDataset]]],
    dataset_factory: Union[
        DatasetFactory,
        Literal['check_if_classification']
    ],
    benchmark_factory: Union[Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ], Literal['check_if_classification']]) -> Tuple[
        DatasetFactory,
        Callable[[
            TStreamsUserDict,
            bool
        ], TDatasetScenario]]:

    check_implicit_classification = \
        dataset_factory == 'check_if_classification' or \
        benchmark_factory == 'check_if_classification'

    is_implicit_classification = False
    if check_implicit_classification:
        all_datasets_iterables = [
            train_datasets,
            test_datasets,
        ]

        if other_streams_datasets is not None:
            all_datasets_iterables.extend(other_streams_datasets.values())

        is_implicit_classification = _detect_legacy_classification_usage(
            itertools.chain(*all_datasets_iterables)
        )

    if is_implicit_classification:
        warnings.warn(
            '`dataset_benchmark` is being called by passing classification '
            'datasets. It is recommended to switch to '
            '`dataset_classification_benchmark` to make sure a '
            '`ClassificationScenario` is returned',
            DeprecationWarning
        )

    dataset_factory_compat: DatasetFactory
    if dataset_factory == 'check_if_classification':
        if is_implicit_classification:
            dataset_factory_compat = make_classification_dataset
        else:
            dataset_factory_compat = make_generic_dataset
    else:
        dataset_factory_compat = dataset_factory

    benchmark_factory_compat: Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ]
    if benchmark_factory == 'check_if_classification':
        if is_implicit_classification:
            benchmark_factory_compat = \
                _make_classification_scenario  # type: ignore
        else:
            benchmark_factory_compat = _make_generic_scenario
    else:
        benchmark_factory_compat = benchmark_factory

    return dataset_factory_compat, benchmark_factory_compat


def create_multi_dataset_generic_benchmark(
    train_datasets: Sequence[GenericSupportedDataset],
    test_datasets: Sequence[GenericSupportedDataset],
    *,
    other_streams_datasets: Optional[
        Mapping[str, Sequence[GenericSupportedDataset]]] = None,
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: Union[
        DatasetFactory,
        Literal['check_if_classification']
    ] = 'check_if_classification',
    benchmark_factory: Union[Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ], Literal['check_if_classification']] = 'check_if_classification'
) -> TDatasetScenario:
    """
    Creates a benchmark instance given a list of datasets. Each dataset will be
    considered as a separate experience.

    Contents of the datasets must already be set, including task labels.
    Transformations will be applied if defined.

    This function allows for the creation of custom streams as well.
    While "train" and "test" datasets must always be set, the experience list
    for other streams can be defined by using the `other_streams_datasets`
    parameter.

    If transformations are defined, they will be applied to the datasets
    of the related stream.

    :param train_datasets: A list of training datasets.
    :param test_datasets: A list of test datasets.
    :param other_streams_datasets: A dictionary describing the content of
        custom streams. Keys must be valid stream names (letters and numbers,
        not starting with a number) while each value must be a list of
        datasets. If this dictionary contains the definition for "train" or
        "test" streams then those definitions will override the
        `train_datasets` and `test_datasets` parameters.
    :param complete_test_set_only: If True, only the complete test set will
        be returned by the benchmark. This means that the ``test_datasets``
        parameter must be a list with a single element (the complete test
        set). Defaults to False.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param eval_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.
    :param other_streams_transforms: Transformations to apply to custom
        streams. If no transformations are defined for a custom stream,
        then "train" transformations will be used. This parameter must be a
        dictionary mapping stream names to transformations. The transformations
        must be a two-element tuple where the first element defines the
        X transformation while the second element is the Y transformation.
        Those elements can be None. If this dictionary contains the
        transformations for "train" or "test" streams then those
        transformations will override the `train_transform`,
        `train_target_transform`, `eval_transform` and
        `eval_target_transform` parameters.
    :param dataset_factory: The factory for the dataset. Should return
        an :class:`AvalancheDataset` (or any subclass) given the input
        dataset, the transform groups definition and the name of the
        initial group (equal to the name of the stream). Defaults to
        'check_if_classification', which uses
        :func:`make_classification_dataset` when all input datasets are
        classification datasets and :func:`make_generic_dataset` otherwise.
    :param benchmark_factory: The factory for the benchmark.
        Should return the benchmark instance given the stream definitions
        and a flag stating if the test stream contains a single dataset.
        Defaults to 'check_if_classification', which returns a
        :class:`ClassificationScenario` when all input datasets are
        classification datasets and a :class:`DatasetScenario` otherwise.

    :returns: A benchmark instance.
    """

    dataset_factory_compat, benchmark_factory_compat = \
        _manage_legacy_classification_usage(
            train_datasets=train_datasets,
            test_datasets=test_datasets,
            other_streams_datasets=other_streams_datasets,
            dataset_factory=dataset_factory,
            benchmark_factory=benchmark_factory
        )

    transform_groups = dict(
        train=(train_transform, train_target_transform),
        eval=(eval_transform, eval_target_transform),
    )

    if other_streams_transforms is not None:
        for stream_name, stream_transforms in other_streams_transforms.items():
            if isinstance(stream_transforms, Sequence):
                if len(stream_transforms) == 1:
                    # Suppose we got only the transformation for X values
                    warnings.warn(
                        'Transformations for other streams should be passed '
                        'as a 2 elements tuple `(Xtransform, YTransform)`. '
                        'You can pass None for the Y transformation.'
                    )
                    stream_transforms = (
                        stream_transforms[0],  # type: ignore
                        None)
            else:
                # Suppose it's the transformation for X values
                warnings.warn(
                    'Transformations for other streams should be passed '
                    'as a 2 elements tuple (Xtransform, YTransform).'
                )
                stream_transforms = (stream_transforms, None)

            transform_groups[stream_name] = stream_transforms

    input_streams = dict(train=train_datasets, test=test_datasets)

    if other_streams_datasets is not None:
        input_streams = {**input_streams, **other_streams_datasets}

    if complete_test_set_only:
        if len(input_streams["test"]) != 1:
            raise ValueError(
                "Test stream must contain one experience when "
                "complete_test_set_only is True"
            )

    stream_definitions: Dict[str, Tuple[Iterable[AvalancheDataset]]] = \
        dict()

    for stream_name, dataset_list in input_streams.items():
        initial_transform_group = "train"
        if stream_name in transform_groups:
            initial_transform_group = stream_name

        stream_datasets = []
        for dataset_idx in range(len(dataset_list)):
            dataset = dataset_list[dataset_idx]

            stream_datasets.append(
                dataset_factory_compat(
                    dataset=dataset,
                    transform_groups=transform_groups,
                    initial_transform_group=initial_transform_group
                )
            )
        stream_definitions[stream_name] = (stream_datasets,)

    return benchmark_factory_compat(
        stream_definitions,
        complete_test_set_only,
    )


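# A minimal usage sketch (not part of the original module). It builds two tiny
# training experiences and two test experiences from random tensors and passes
# them to create_multi_dataset_generic_benchmark. Treating plain torch
# TensorDataset objects (which expose a `.tensors` field) as valid
# GenericSupportedDataset inputs is an assumption made only for this
# illustration; with integer targets the helper will emit the
# DeprecationWarning above and build a classification scenario.
def _example_multi_dataset_benchmark():
    import torch
    from torch.utils.data import TensorDataset

    def _random_split(n_samples):
        x = torch.rand(n_samples, 3, 32, 32)
        y = torch.randint(0, 10, (n_samples,))
        return TensorDataset(x, y)

    benchmark = create_multi_dataset_generic_benchmark(
        train_datasets=[_random_split(100), _random_split(100)],
        test_datasets=[_random_split(50), _random_split(50)],
    )

    # One experience per input dataset, in order. The `train_stream` and
    # `current_experience` names are assumed from the scenario/experience
    # classes used above.
    for experience in benchmark.train_stream:
        print(experience.current_experience, len(experience.dataset))
    return benchmark

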
def _adapt_lazy_stream(
        generator,
        transform_groups,
        initial_transform_group,
        dataset_factory):
    """
    A simple internal utility to apply transforms and dataset type to all lazily
    generated datasets. Used in the :func:`create_lazy_generic_benchmark`
    benchmark creation helper.

    :return: A generator of datasets in which the proper transformation groups
        and dataset type are applied.
    """

    for dataset in generator:
        dataset = dataset_factory(
            dataset,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
        )
        yield dataset


class LazyStreamDefinition(NamedTuple):
    """
    A simple class that can be used when preparing the parameters for the
    :func:`create_lazy_generic_benchmark` helper.

    This class is a named tuple containing the fields required for defining
    a lazily-created benchmark.

    - exps_generator: The experiences generator. Can be a "yield"-based
      generator, a custom sequence, a standard list or any kind of
      iterable returning :class:`AvalancheDataset`.
    - stream_length: The number of experiences in the stream. Must match the
      number of experiences returned by the generator.
    - exps_task_labels: A list containing the list of task labels of each
      experience. If an experience contains a single task label, a single int
      can be used.
    """

    exps_generator: Iterable[AvalancheDataset]
    """
    The experiences generator. Can be a "yield"-based generator, a custom
    sequence, a standard list or any kind of iterable returning
    :class:`AvalancheDataset`.
    """

    stream_length: int
    """
    The number of experiences in the stream. Must match the number of
    experiences returned by the generator.
    """

    exps_task_labels: Sequence[Union[int, Iterable[int]]]
    """
    A list containing the list of task labels of each experience.
    If an experience contains a single task label, a single int can be used.

    This field is temporarily required for internal purposes to support lazy
    streams. This field may become optional in the future.
    """


def create_lazy_generic_benchmark(
    train_generator: LazyStreamDefinition,
    test_generator: LazyStreamDefinition,
    *,
    other_streams_generators: Optional[Dict[str, LazyStreamDefinition]] = None,
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: DatasetFactory = make_generic_dataset,
    benchmark_factory: Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ] = _make_generic_scenario
) -> TDatasetScenario:
    """
    Creates a lazily-defined benchmark instance given a dataset generator for
    each stream.

    Generators must return properly initialized instances of
    :class:`AvalancheDataset` which will be used to create experiences.

    The created datasets can have transformations already set.
    However, if transformations are shared across all datasets of the same
    stream, it is recommended to use the `train_transform`, `eval_transform`
    and `other_streams_transforms` parameters, so that transformation groups
    can be correctly applied (transformations are lazily added atop the datasets
    returned by the generators).

    This function allows for the creation of custom streams as well.
    While "train" and "test" streams must always be set, the generators
    for other streams can be defined by using the `other_streams_generators`
    parameter.

    :param train_generator: A proper lazy-generation definition for the training
        stream. It is recommended to pass an instance
        of :class:`LazyStreamDefinition`. See its description for more details.
    :param test_generator: A proper lazy-generation definition for the test
        stream. It is recommended to pass an instance
        of :class:`LazyStreamDefinition`. See its description for more details.
    :param other_streams_generators: A dictionary describing the content of
        custom streams. Keys must be valid stream names (letters and numbers,
        not starting with a number) while each value must be a
        lazy-generation definition (like the ones of the training and
        test streams). If this dictionary contains the definition for
        "train" or "test" streams then those definitions will override the
        `train_generator` and `test_generator` parameters.
    :param complete_test_set_only: If True, only the complete test set will
        be returned by the benchmark. This means that the ``test_generator``
        parameter must define a stream with a single experience (the complete
        test set). Defaults to False.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param eval_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.
    :param other_streams_transforms: Transformations to apply to custom
        streams. If no transformations are defined for a custom stream,
        then "train" transformations will be used. This parameter must be a
        dictionary mapping stream names to transformations. The transformations
        must be a two-element tuple where the first element defines the
        X transformation while the second element is the Y transformation.
        Those elements can be None. If this dictionary contains the
        transformations for "train" or "test" streams then those transformations
        will override the `train_transform`, `train_target_transform`,
        `eval_transform` and `eval_target_transform` parameters.
    :param dataset_factory: The factory for the dataset. Should return
        an :class:`AvalancheDataset` (or any subclass) given the input
        dataset, the transform groups definition and the name of the
        initial group (equal to the name of the stream). Defaults
        to :func:`make_generic_dataset`.
    :param benchmark_factory: The factory for the benchmark.
        Should return the benchmark instance given the stream definitions
        and a flag stating if the test stream contains a single dataset.
        By default, returns a :class:`DatasetScenario`.

    :returns: A lazily-initialized benchmark instance.
    """

    transform_groups = dict(
        train=(train_transform, train_target_transform),
        eval=(eval_transform, eval_target_transform),
    )

    if other_streams_transforms is not None:
        for stream_name, stream_transforms in other_streams_transforms.items():
            if isinstance(stream_transforms, Sequence):
                if len(stream_transforms) == 1:
                    # Suppose we got only the transformation for X values
                    warnings.warn(
                        'Transformations for other streams should be passed '
                        'as a 2 elements tuple `(Xtransform, YTransform)`. '
                        'You can pass None for the Y transformation.'
                    )
                    stream_transforms = (
                        stream_transforms[0],  # type: ignore
                        None)
            else:
                # Suppose it's the transformation for X values
                warnings.warn(
                    'Transformations for other streams should be passed '
                    'as a 2 elements tuple (Xtransform, YTransform).'
                )
                stream_transforms = (stream_transforms, None)

            transform_groups[stream_name] = stream_transforms

    input_streams = dict(train=train_generator, test=test_generator)

    if other_streams_generators is not None:
        input_streams = {**input_streams, **other_streams_generators}

    if complete_test_set_only:
        if input_streams["test"][1] != 1:
            raise ValueError(
                "Test stream must contain one experience when "
                "complete_test_set_only is True"
            )

    stream_definitions: Dict[
        str, Tuple[
            # Dataset generator + stream length
            Tuple[Generator[AvalancheDataset, None, None], int],
            # Task label(s) for each experience
            Iterable[Union[int, Iterable[int]]]
            ]
        ] = dict()

    for stream_name, (
        generator,
        stream_length,
        task_labels,
    ) in input_streams.items():
        initial_transform_group = "train"
        if stream_name in transform_groups:
            initial_transform_group = stream_name

        adapted_stream_generator = _adapt_lazy_stream(
            generator,
            transform_groups,
            initial_transform_group=initial_transform_group,
            dataset_factory=dataset_factory
        )

        stream_definitions[stream_name] = (
            (adapted_stream_generator, stream_length),
            task_labels,
        )

    return benchmark_factory(
        stream_definitions,
        complete_test_set_only
    )


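# A minimal usage sketch (not part of the original module) showing how the
# LazyStreamDefinition fields map onto create_lazy_generic_benchmark. The
# tensor shapes and the use of task label 0 for every experience are
# assumptions made only for this illustration.
def _example_lazy_benchmark():
    import torch

    def _experience_generator(n_experiences):
        # Datasets are created only when the stream is actually iterated.
        for _ in range(n_experiences):
            x = torch.rand(64, 3, 32, 32)
            y = torch.randint(0, 10, (64,))
            yield make_generic_tensor_dataset([x, y], task_labels=0)

    train_def = LazyStreamDefinition(
        exps_generator=_experience_generator(3),
        stream_length=3,
        exps_task_labels=[0, 0, 0],
    )
    test_def = LazyStreamDefinition(
        exps_generator=_experience_generator(1),
        stream_length=1,
        exps_task_labels=[0],
    )
    return create_lazy_generic_benchmark(train_def, test_def)

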

682
def create_generic_benchmark_from_filelists(
4✔
683
    root: Optional[Union[str, Path]],
684
    train_file_lists: Sequence[Union[str, Path]],
685
    test_file_lists: Sequence[Union[str, Path]],
686
    *,
687
    other_streams_file_lists: Optional[
688
        Dict[str, Sequence[Union[str, Path]]]] = None,
689
    task_labels: Sequence[int],
690
    complete_test_set_only: bool = False,
691
    train_transform: XTransform = None,
692
    train_target_transform: YTransform = None,
693
    eval_transform: XTransform = None,
694
    eval_target_transform: YTransform = None,
695
    other_streams_transforms: Optional[
696
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
697
    dataset_factory: DatasetFactory = make_classification_dataset,
698
    benchmark_factory: Callable[
699
        [
700
            TStreamsUserDict,
701
            bool
702
        ], TDatasetScenario
703
    ] = _make_classification_scenario  # type: ignore
704
) -> TDatasetScenario:
705
    """
706
    Creates a benchmark instance given a list of filelists and the respective
707
    task labels. A separate dataset will be created for each filelist and each
708
    of those datasets will be considered a separate experience.
709

710
    This helper functions is the best shot when loading Caffe-style dataset
711
    based on filelists.
712

713
    Beware that this helper function is limited is the following two aspects:
714

715
    - The resulting benchmark instance and the intermediate datasets used to
716
      populate it will be of type CLASSIFICATION.
717
    - Task labels can only be defined by choosing a single task label for
718
      each experience (the same task label is applied to all patterns of
719
      experiences sharing the same position in different streams).
720

721
    Despite those constraints, this helper function is usually sufficiently
722
    powerful to cover most continual learning benchmarks based on file lists.
723

724
    When in need to create a similar benchmark instance starting from an
725
    in-memory list of paths, then the similar helper function
726
    :func:`create_generic_benchmark_from_paths` can be used.
727

728
    When in need to create a benchmark instance in which task labels are defined
729
    in a more fine-grained way, then consider using
730
    :func:`create_multi_dataset_generic_benchmark` by passing properly
731
    initialized :class:`AvalancheDataset` instances.
732

733
    :param root: The root path of the dataset. Can be None.
734
    :param train_file_lists: A list of filelists describing the
735
        paths of the training patterns for each experience.
736
    :param test_file_lists: A list of filelists describing the
737
        paths of the test patterns for each experience.
738
    :param other_streams_file_lists: A dictionary describing the content of
739
        custom streams. Keys must be valid stream names (letters and numbers,
740
        not starting with a number) while the value must be a list of filelists
741
        (same as `train_file_lists` and `test_file_lists` parameters). If this
742
        dictionary contains the definition for "train" or "test" streams then
743
        those definition will  override the `train_file_lists` and
744
        `test_file_lists` parameters.
745
    :param task_labels: A list of task labels. Must contain at least a value
746
        for each experience. Each value describes the task label that will be
747
        applied to all patterns of a certain experience. For more info on that,
748
        see the function description.
749
    :param complete_test_set_only: If True, only the complete test set will
750
        be returned by the benchmark. This means that the ``test_file_lists``
751
        parameter must be list with a single element (the complete test set).
752
        Alternatively, can be a plain string or :class:`Path` object.
753
        Defaults to False.
754
    :param train_transform: The transformation to apply to the training data,
755
        e.g. a random crop, a normalization or a concatenation of different
756
        transformations (see torchvision.transform documentation for a
757
        comprehensive list of possible transformations). Defaults to None.
758
    :param train_target_transform: The transformation to apply to training
759
        patterns targets. Defaults to None.
760
    :param eval_transform: The transformation to apply to the test data,
761
        e.g. a random crop, a normalization or a concatenation of different
762
        transformations (see torchvision.transform documentation for a
763
        comprehensive list of possible transformations). Defaults to None.
764
    :param eval_target_transform: The transformation to apply to test
765
        patterns targets. Defaults to None.
766
    :param other_streams_transforms: Transformations to apply to custom
767
        streams. If no transformations are defined for a custom stream,
768
        then "train" transformations will be used. This parameter must be a
769
        dictionary mapping stream names to transformations. The transformations
770
        must be a two elements tuple where the first element defines the
771
        X transformation while the second element is the Y transformation.
772
        Those elements can be None. If this dictionary contains the
773
        transformations for "train" or "test" streams then those transformations
774
        will override the `train_transform`, `train_target_transform`,
775
        `eval_transform` and `eval_target_transform` parameters.
776
    :param dataset_factory: The factory for the dataset. Should return
777
        an :class:`AvalancheDataset` (or any subclass) given the input
778
        dataset, the transform groups definition and the name of the
779
        initial group (equal to the name of the stream). Defaults
780
        to :func:`make_classification_dataset`.
781
    :param benchmark_factory: The factory for the benchmark.
782
        Should return the benchmark instance given the stream definitions
783
        and a flag stating if the test stream contains a single dataset.
784
        By default, returns a :class:`ClassificationScenario`.
785

786
    :returns: A benchmark instance.
787
    """
788

789
    input_streams = dict(train=train_file_lists, test=test_file_lists)
4✔
790

791
    if other_streams_file_lists is not None:
4✔
792
        input_streams = {**input_streams, **other_streams_file_lists}
×
793

794
    stream_definitions: Dict[str, Sequence[AvalancheDataset]] = dict()
4✔
795

796
    for stream_name, file_lists in input_streams.items():
4✔
797
        stream_datasets: List[AvalancheDataset] = []
4✔
798
        for exp_id, f_list in enumerate(file_lists):
4✔
799

800
            f_list_dataset = FilelistDataset(root, f_list)
4✔
801
            stream_datasets.append(
4✔
802
                dataset_factory(
803
                    f_list_dataset, task_labels=task_labels[exp_id]
804
                )
805
            )
806

807
        stream_definitions[stream_name] = stream_datasets
4✔
808

809
    return create_multi_dataset_generic_benchmark(
4✔
810
        [],
811
        [],
812
        other_streams_datasets=stream_definitions,
813
        train_transform=train_transform,
814
        train_target_transform=train_target_transform,
815
        eval_transform=eval_transform,
816
        eval_target_transform=eval_target_transform,
817
        complete_test_set_only=complete_test_set_only,
818
        other_streams_transforms=other_streams_transforms,
819
        dataset_factory=dataset_factory,
820
        benchmark_factory=benchmark_factory
821
    )
822

823

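# A minimal usage sketch (not part of the original module). The Caffe-style
# filelist format assumed here is one "<relative/path> <int label>" entry per
# line; the file names and directory layout are purely illustrative.
def _example_benchmark_from_filelists():
    # train_exp_0.txt, train_exp_1.txt and test.txt are assumed to contain
    # lines such as:
    #   cats/img_0001.png 0
    #   dogs/img_0042.png 1
    return create_generic_benchmark_from_filelists(
        '/path/to/images',
        ['train_exp_0.txt', 'train_exp_1.txt'],
        ['test.txt'],
        task_labels=[0, 0],
        complete_test_set_only=True,
    )

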
FileAndLabel = Union[
    Tuple[Union[str, Path], int], Tuple[Union[str, Path], int, Sequence]
]


def create_generic_benchmark_from_paths(
    train_lists_of_files: Sequence[Sequence[FileAndLabel]],
    test_lists_of_files: Sequence[Sequence[FileAndLabel]],
    *,
    other_streams_lists_of_files: Optional[Dict[
        str, Sequence[Sequence[FileAndLabel]]
    ]] = None,
    task_labels: Sequence[int],
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: Union[
        DatasetFactory,
        Literal['check_if_classification']
    ] = 'check_if_classification',
    benchmark_factory: Union[Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ], Literal['check_if_classification']] = 'check_if_classification'
) -> TDatasetScenario:
    """
    Creates a benchmark instance given a sequence of lists of files. A separate
    dataset will be created for each list. Each of those datasets
    will be considered a separate experience.

    This is very similar to :func:`create_generic_benchmark_from_filelists`,
    with the main difference being that
    :func:`create_generic_benchmark_from_filelists` accepts, for each
    experience, a file list formatted in Caffe-style. By contrast, this
    function accepts a list of tuples where each tuple contains two elements:
    the full path to the pattern and its label. Optionally, the tuple may
    contain a third element describing the bounding box of the element to
    crop. This last bounding box may be useful when trying to extract the
    part of the image depicting the desired element.

    Apart from that, the same limitations of
    :func:`create_generic_benchmark_from_filelists` regarding task labels apply.

    The label of each pattern doesn't have to be an int. Also, a dataset type
    can be defined.

    :param train_lists_of_files: A list of lists. Each list describes the paths
        and labels of patterns to include in that training experience, as
        tuples. Each tuple must contain two elements: the full path to the
        pattern and its class label. Optionally, the tuple may contain a
        third element describing the bounding box to use for cropping (top,
        left, height, width).
    :param test_lists_of_files: A list of lists. Each list describes the paths
        and labels of patterns to include in that test experience, as tuples.
        Each tuple must contain two elements: the full path to the pattern
        and its class label. Optionally, the tuple may contain a third element
        describing the bounding box to use for cropping (top, left, height,
        width).
    :param other_streams_lists_of_files: A dictionary describing the content of
        custom streams. Keys must be valid stream names (letters and numbers,
        not starting with a number) while the values follow the same structure
        as the `train_lists_of_files` and `test_lists_of_files` parameters. If
        this dictionary contains the definition for "train" or "test" streams
        then those definitions will override the `train_lists_of_files` and
        `test_lists_of_files` parameters.
    :param task_labels: A list of task labels. Must contain at least a value
        for each experience. Each value describes the task label that will be
        applied to all patterns of a certain experience. For more info on that,
        see the function description.
    :param complete_test_set_only: If True, only the complete test set will
        be returned by the benchmark. This means that the
        ``test_lists_of_files`` parameter must define a single experience
        (the complete test set). Defaults to False.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param eval_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.
    :param other_streams_transforms: Transformations to apply to custom
        streams. If no transformations are defined for a custom stream,
        then "train" transformations will be used. This parameter must be a
        dictionary mapping stream names to transformations. The transformations
        must be a two-element tuple where the first element defines the
        X transformation while the second element is the Y transformation.
        Those elements can be None. If this dictionary contains the
        transformations for "train" or "test" streams then those transformations
        will override the `train_transform`, `train_target_transform`,
        `eval_transform` and `eval_target_transform` parameters.
    :param dataset_factory: The factory for the dataset. Should return
        an :class:`AvalancheDataset` (or any subclass) given the input
        dataset, the transform groups definition and the name of the
        initial group (equal to the name of the stream). Defaults to
        'check_if_classification', which uses
        :func:`make_classification_dataset` when all input datasets are
        classification datasets and :func:`make_generic_dataset` otherwise.
    :param benchmark_factory: The factory for the benchmark.
        Should return the benchmark instance given the stream definitions
        and a flag stating if the test stream contains a single dataset.
        Defaults to 'check_if_classification', which returns a
        :class:`ClassificationScenario` when all input datasets are
        classification datasets and a :class:`DatasetScenario` otherwise.

    :returns: A benchmark instance.
    """

    input_streams = dict(train=train_lists_of_files, test=test_lists_of_files)

    if other_streams_lists_of_files is not None:
        input_streams = {**input_streams, **other_streams_lists_of_files}

    stream_definitions: Dict[str, Sequence[AvalancheDataset]] = dict()

    for stream_name, lists_of_files in input_streams.items():
        stream_datasets: List[AvalancheDataset] = []
        for exp_id, list_of_files in enumerate(lists_of_files):
            common_root, exp_paths_list = common_paths_root(list_of_files)
            paths_dataset: PathsDataset[Any, Any] = \
                PathsDataset(common_root, exp_paths_list)
            stream_datasets.append(
                make_generic_dataset(
                    paths_dataset,
                    task_labels=task_labels[exp_id]
                )
            )

        stream_definitions[stream_name] = stream_datasets

    return create_multi_dataset_generic_benchmark(
        [],
        [],
        other_streams_datasets=stream_definitions,
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        complete_test_set_only=complete_test_set_only,
        other_streams_transforms=other_streams_transforms,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory
    )


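# A minimal usage sketch (not part of the original module). The image paths
# and class labels below are purely illustrative; the optional third tuple
# element (the bounding box) is omitted here.
def _example_benchmark_from_paths():
    train_experiences = [
        [  # experience 0
            ('/data/cats/img_0001.png', 0),
            ('/data/dogs/img_0042.png', 1),
        ],
        [  # experience 1
            ('/data/birds/img_0007.png', 2),
            ('/data/fish/img_0013.png', 3),
        ],
    ]
    test_experience = [
        ('/data/cats/img_0100.png', 0),
        ('/data/birds/img_0101.png', 2),
    ]
    return create_generic_benchmark_from_paths(
        train_experiences,
        [test_experience],
        task_labels=[0, 1],
        complete_test_set_only=True,
    )

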
def create_generic_benchmark_from_tensor_lists(
    train_tensors: Sequence[Sequence[Any]],
    test_tensors: Sequence[Sequence[Any]],
    *,
    other_streams_tensors: Optional[Dict[str, Sequence[Sequence[Any]]]] = None,
    task_labels: Sequence[int],
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: Union[
        DatasetFactory,
        Literal['check_if_classification']
    ] = 'check_if_classification',
    benchmark_factory: Union[Callable[
        [
            TStreamsUserDict,
            bool
        ], TDatasetScenario
    ], Literal['check_if_classification']] = 'check_if_classification'
) -> TDatasetScenario:
    """
    Creates a benchmark instance given lists of Tensors. A separate dataset will
    be created from each Tensor tuple (x, y, z, ...) and each of those training
    datasets will be considered a separate training experience. Using this
    helper function is the lowest-level way to create a Continual Learning
    benchmark. When possible, consider using higher-level helpers.

    Experiences are defined by passing lists of tensors as the `train_tensors`,
    `test_tensors` (and `other_streams_tensors`) parameters. Those parameters
    must be lists containing lists of tensors, one list for each experience.
    Each tensor defines the value of a feature ("x", "y", "z", ...) for all
    patterns of that experience.

    By default the second tensor of each experience will be used to fill the
    `targets` value (label of each pattern).

    Beware that task labels can only be defined by choosing a single task label
    for each experience (the same task label is applied to all patterns of
    experiences sharing the same position in different streams).

    If you need to create a benchmark instance in which task labels are defined
    in a more fine-grained way, consider using
    :func:`create_multi_dataset_generic_benchmark` by passing properly
    initialized :class:`AvalancheDataset` instances.

    :param train_tensors: A list of lists. The first list must contain the
        tensors for the first training experience (one tensor per feature), the
        second list must contain the tensors for the second training experience,
        and so on.
    :param test_tensors: A list of lists. The first list must contain the
        tensors for the first test experience (one tensor per feature), the
        second list must contain the tensors for the second test experience,
        and so on. When using `complete_test_set_only`, this parameter
        must be a list containing a single sub-list for the single test
        experience.
    :param other_streams_tensors: A dictionary describing the content of
        custom streams. Keys must be valid stream names (letters and numbers,
        not starting with a number) while the values follow the same structure
        as the `train_tensors` and `test_tensors` parameters. If this
        dictionary contains the definition for "train" or "test" streams then
        those definitions will override the `train_tensors` and `test_tensors`
        parameters.
    :param task_labels: A list of task labels. Must contain at least a value
        for each experience. Each value describes the task label that will be
        applied to all patterns of a certain experience. For more info on that,
        see the function description.
    :param complete_test_set_only: If True, only the complete test set will
        be returned by the benchmark. This means that ``test_tensors`` must
        define a single experience. Defaults to False.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations). Defaults to None.
    :param eval_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.
    :param other_streams_transforms: Transformations to apply to custom
        streams. If no transformations are defined for a custom stream,
        then "train" transformations will be used. This parameter must be a
        dictionary mapping stream names to transformations. The transformations
        must be a two-element tuple where the first element defines the
        X transformation while the second element is the Y transformation.
        Those elements can be None. If this dictionary contains the
        transformations for "train" or "test" streams then those transformations
        will override the `train_transform`, `train_target_transform`,
        `eval_transform` and `eval_target_transform` parameters.
    :param dataset_factory: The factory for the dataset. Should return
        an :class:`AvalancheDataset` (or any subclass) given the input
        dataset, the transform groups definition and the name of the
        initial group (equal to the name of the stream). Defaults to
        'check_if_classification', which uses
        :func:`make_classification_dataset` when all input datasets are
        classification datasets and :func:`make_generic_dataset` otherwise.
    :param benchmark_factory: The factory for the benchmark.
        Should return the benchmark instance given the stream definitions
        and a flag stating if the test stream contains a single dataset.
        Defaults to 'check_if_classification', which returns a
        :class:`ClassificationScenario` when all input datasets are
        classification datasets and a :class:`DatasetScenario` otherwise.

    :returns: A benchmark instance.
    """

    input_streams = dict(train=train_tensors, test=test_tensors)

    if other_streams_tensors is not None:
        input_streams = {**input_streams, **other_streams_tensors}

    stream_definitions: Dict[str, Sequence[AvalancheDataset]] = dict()

    for stream_name, list_of_exps_tensors in input_streams.items():
        stream_datasets: List[AvalancheDataset] = []
        for exp_id, exp_tensors in enumerate(list_of_exps_tensors):
            stream_datasets.append(
                make_generic_tensor_dataset(
                    exp_tensors, task_labels=task_labels[exp_id]
                )
            )

        stream_definitions[stream_name] = stream_datasets

    return create_multi_dataset_generic_benchmark(
        [],
        [],
        other_streams_datasets=stream_definitions,
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        complete_test_set_only=complete_test_set_only,
        other_streams_transforms=other_streams_transforms,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory
    )


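# A minimal usage sketch (not part of the original module): two training
# experiences and one test experience, each defined by an (x, y) tensor pair.
# The tensor shapes and integer targets are purely illustrative.
def _example_benchmark_from_tensor_lists():
    import torch

    def _random_experience(n_samples):
        x = torch.rand(n_samples, 3, 32, 32)
        y = torch.randint(0, 10, (n_samples,))
        return [x, y]  # one tensor per feature: "x", then "y" (the targets)

    return create_generic_benchmark_from_tensor_lists(
        train_tensors=[_random_experience(100), _random_experience(100)],
        test_tensors=[_random_experience(50)],
        task_labels=[0, 1],
        complete_test_set_only=True,
    )

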
__all__ = [
    "create_multi_dataset_generic_benchmark",
    "LazyStreamDefinition",
    "create_lazy_generic_benchmark",
    "create_generic_benchmark_from_filelists",
    "create_generic_benchmark_from_paths",
    "create_generic_benchmark_from_tensor_lists",
]