ContinualAI / avalanche, build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200 (github, web-flow)

417 of 538 new or added lines in 30 files covered (77.51%).
43 existing lines in 5 files now uncovered.
16586 of 22630 relevant lines covered (73.29%); 2.93 hits per line.

Source file: /avalanche/benchmarks/utils/classification_dataset.py (94.35% covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini, Antonio Carta                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the ``ClassificationDataset``,
which is the dataset used for supervised continual learning benchmarks.
ClassificationDatasets are ``AvalancheDatasets`` that manage class and task
labels automatically. Concatenation and subsampling operations are optimized
to be used frequently, as is common in replay strategies.
"""

from functools import partial
import torch
from torch.utils.data.dataset import Subset, ConcatDataset, TensorDataset

from avalanche.benchmarks.utils.utils import (
    TaskSet,
    _count_unique,
    find_common_transforms_group,
    _init_task_labels,
    _init_transform_groups,
    _split_user_def_targets,
    _split_user_def_task_label,
    _traverse_supported_dataset,
)

from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.transform_groups import (
    TransformGroupDef,
    DefaultTransformGroups,
    TransformGroups,
    XTransform,
    YTransform,
)
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.benchmarks.utils.dataset_utils import (
    SubSequence,
)
from avalanche.benchmarks.utils.flat_data import ConstantSequence
from avalanche.benchmarks.utils.dataset_definitions import (
    IDataset,
    ISupportedClassificationDataset,
    ITensorDataset,
    IDatasetWithTargets,
)

from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
)


T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TTargetType = int

TClassificationDataset = TypeVar(
    "TClassificationDataset",
    bound="ClassificationDataset"
)


def lookup(indexable, idx):
    """
    A simple function that implements indexing into an indexable object.
    Together with 'partial' this allows us to circumvent lambda functions
    that cannot be pickled.
    """
    return indexable[idx]

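# A minimal sketch (not part of the Avalanche API) of why ``lookup`` is paired
# with ``partial`` instead of a lambda: the partial can be pickled, e.g. when
# handed to DataLoader workers, while an equivalent local lambda cannot.
# ``_example_lookup_pickling`` is a hypothetical demo function.
def _example_lookup_pickling():
    import pickle

    class_mapping = [2, 0, 1]
    remap = partial(lookup, class_mapping)
    assert remap(0) == 2

    # The partial survives a pickling round-trip...
    restored = pickle.loads(pickle.dumps(remap))
    assert restored(1) == 0

    # ...while the equivalent local lambda raises a pickling error.
    try:
        pickle.dumps(lambda idx: class_mapping[idx])
    except (pickle.PicklingError, AttributeError, TypeError):
        pass

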
class ClassificationDataset(AvalancheDataset[T_co]):

    def __init__(
            self,
            datasets: Sequence[IDataset[T_co]],
            *,
            indices: Optional[List[int]] = None,
            data_attributes: Optional[List[DataAttribute]] = None,
            transform_groups: Optional[TransformGroups] = None,
            frozen_transform_groups: Optional[TransformGroups] = None,
            collate_fn: Optional[Callable[[List], Any]] = None):
        super().__init__(
            datasets=datasets,
            indices=indices,
            data_attributes=data_attributes,
            transform_groups=transform_groups,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn
        )

        assert 'targets' in self._data_attributes, \
            'The supervised version of the ClassificationDataset requires ' + \
            'the targets field'
        assert 'targets_task_labels' in self._data_attributes, \
            'The supervised version of the ClassificationDataset requires ' + \
            'the targets_task_labels field'

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes['targets']

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes['targets_task_labels']

    @property
    def task_pattern_indices(self):
        """A dictionary mapping task ids to their sample indices."""
        return self.targets_task_labels.val_to_idx

    @property
    def task_set(self: TClassificationDataset) -> \
            TaskSet[TClassificationDataset]:
        """Returns the dataset's ``TaskSet``, which is a mapping <task-id,
        task-dataset>."""
        return TaskSet(self)

    def subset(self, indices):
        data = super().subset(indices)
        return data.with_transforms(
            self._flat_data._transform_groups.current_group)

    def concat(self, other):
        data = super().concat(other)
        return data.with_transforms(
            self._flat_data._transform_groups.current_group)

    def __hash__(self):
        return id(self)

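# A short, hypothetical usage sketch of the operations defined above:
# ``subset``/``concat`` keep the currently selected transforms group, and
# ``task_set`` maps each task label to its sub-dataset. The demo assumes
# ``make_tensor_classification_dataset``, defined later in this module.
def _example_classification_dataset_ops():
    x = torch.randn(8, 4)
    y = torch.tensor([0, 1, 0, 1, 2, 2, 0, 1])
    data = make_tensor_classification_dataset(x, y, task_labels=0)

    # 'targets' and 'targets_task_labels' are iterable DataAttributes.
    assert list(data.targets) == y.tolist()
    assert set(data.targets_task_labels) == {0}

    # Subset and concatenation return ClassificationDatasets.
    both = data.subset(range(4)).concat(data.subset(range(4, 8)))
    assert len(both) == len(data)

    # The task set behaves as a mapping <task-id, task-dataset>.
    assert len(data.task_set[0]) == 8

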
SupportedDataset = Union[
    IDatasetWithTargets,
    ITensorDataset,
    Subset,
    ConcatDataset,
    ClassificationDataset
]


def make_classification_dataset(
    dataset: SupportedDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> ClassificationDataset:
    """Avalanche Classification Dataset.

    Supervised continual learning benchmarks in Avalanche return instances of
    this dataset, but it can also be used in a completely standalone manner.

    This dataset applies input/target transformations, supports slicing and
    advanced indexing, and contains useful fields such as `targets`, which
    contains the pattern labels, and `targets_task_labels`, which contains
    the pattern task labels. The `task_set` field can be used to obtain the
    subset of patterns labeled with a given task label.

    This dataset can also be used to apply several advanced operations
    involving transformations. For instance, it allows the user to add and
    replace transformations, freeze them so that they can't be changed, etc.

    This dataset also allows the user to keep distinct transformation groups.
    Simply put, a transformation group is a pair of transform+target_transform
    (exactly as in torchvision datasets). This dataset natively supports
    keeping two transformation groups: the first, 'train', contains
    transformations applied to training patterns. Those transformations
    usually involve some kind of data augmentation. The second one, 'eval',
    contains transformations applied to test patterns. Having both groups can
    be useful when, for instance, one needs to test on the training data (as
    this process usually involves removing data augmentation operations).
    Switching between transformations can be easily achieved by using the
    :func:`train` and :func:`eval` methods.

    Moreover, arbitrary transformation groups can be added and used. For more
    info see the constructor and the :func:`with_transforms` method.

    This dataset will try to inherit the task labels from the input
    dataset. If none are available and none are given via the `task_labels`
    parameter, each pattern will be assigned a default task label 0.

    Creates an ``AvalancheDataset`` instance.

    :param dataset: The dataset to decorate. Beware that
        AvalancheDataset will not overwrite transformations already
        applied by this dataset.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label of each instance. Must be a sequence
        of ints, one for each instance in the dataset. Alternatively, can be
        a single int value, in which case that value will be used as the
        task label for all the instances. Defaults to None, which means that
        the dataset will try to obtain the task labels from the original
        dataset. If no task labels could be found, a default task label
        0 will be applied to all instances.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the dataset (if
        possible).
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is also used in the data loading process.
        If None, the constructor will check if a `collate_fn` field exists
        in the dataset. If no such field exists, the default collate
        function will be used.
    """

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data: Optional[DataAttribute[TTargetType]] = \
        _init_targets(dataset, targets)
    task_labels_data: Optional[DataAttribute[int]] = \
        _init_task_labels(dataset, task_labels)

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    data: ClassificationDataset = ClassificationDataset(
        [dataset],
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        collate_fn=collate_fn,
    )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data

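# A hypothetical sketch of ``make_classification_dataset``: wrap a plain
# torchvision-style dataset, assign a constant task label, and keep separate
# 'train'/'eval' transform groups. ``_ToyDataset`` and the example function
# are illustrative names, not part of the Avalanche API.
class _ToyDataset:
    def __init__(self, xs, ys):
        self.xs, self.targets = xs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.targets[idx]


def _example_make_classification_dataset():
    plain = _ToyDataset(list(torch.randn(6, 4)), [0, 1, 2, 0, 1, 2])
    data = make_classification_dataset(
        plain,
        task_labels=0,  # constant task label for every pattern
        transform_groups={
            # 'train' adds a little noise, 'eval' leaves inputs untouched.
            "train": (lambda x: x + 0.1 * torch.randn_like(x), None),
            "eval": (None, None),
        },
        initial_transform_group="train",
    )

    # Switch to the augmentation-free group before testing.
    data = data.eval()
    assert list(data.targets) == [0, 1, 2, 0, 1, 2]

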
def _init_targets(dataset, targets, check_shape=True) -> \
        Optional[DataAttribute[TTargetType]]:
    if targets is not None:
        # User-defined targets always take precedence
        if isinstance(targets, int):
            targets = ConstantSequence(targets, len(dataset))
        elif len(targets) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid amount of target labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(targets), len(dataset))
            )
        return DataAttribute(targets, "targets")

    targets = _traverse_supported_dataset(
        dataset, _select_targets)

    if targets is None:
        return None

    if isinstance(targets, torch.Tensor):
        targets = targets.tolist()

    return DataAttribute(targets, "targets")

def classification_subset(
    dataset: SupportedDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str,
                                       Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> ClassificationDataset:
    """Creates an ``AvalancheSubset`` instance.

    For simple subset operations you should use the method
    `dataset.subset(indices)`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`.
    This Dataset also supports transformations, slicing, advanced indexing,
    the targets field, class mapping and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset: The whole dataset.
    :param indices: Indices in the whole set selected for subset. Can
        be None, which means that the whole dataset will be returned.
    :param class_mapping: A list that, for each possible target (Y) value,
        contains its corresponding remapped value. Can be None.
        Beware that setting this parameter will force the final
        dataset type to be CLASSIFICATION or UNDEFINED.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label for each instance. Must be a sequence
        of ints, one for each instance in the dataset. This can either be a
        list of task labels for the original dataset or the list of task
        labels for the instances of the subset (an automatic detection will
        be made). In the unfortunate case in which the original dataset and
        the subset contain the same amount of instances, this parameter is
        considered to contain the task labels of the subset.
        Alternatively, it can be a single int value, in which case
        that value will be used as the task label for all the instances.
        Defaults to None, which means that the dataset will try to
        obtain the task labels from the original dataset. If no task labels
        could be found, a default task label 0 will be applied to all
        instances.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the dataset (if
        possible). This can either be a list of target labels for the
        original dataset or the list of target labels for the instances of
        the subset (an automatic detection will be made). In the unfortunate
        case in which the original dataset and the subset contain the same
        amount of instances, this parameter is considered to contain the
        target labels of the subset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is also used in the data loading process.
        If None, the constructor will check if a `collate_fn` field exists
        in the dataset. If no such field exists, the default collate
        function will be used.
    """

    if isinstance(dataset, ClassificationDataset):
        if (
            class_mapping is None
            and transform is None
            and target_transform is None
            and transform_groups is None
            and initial_transform_group is None
            and task_labels is None
            and targets is None
            and collate_fn is None
        ):
            return dataset.subset(indices)

    targets_data: Optional[DataAttribute[TTargetType]] = \
        _init_targets(dataset, targets, check_shape=False)
    task_labels_data: Optional[DataAttribute[int]] = \
        _init_task_labels(dataset, task_labels, check_shape=False)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )

    if initial_transform_group is not None and isinstance(
        dataset, AvalancheDataset
    ):
        dataset = dataset.with_transforms(initial_transform_group)

    if class_mapping is not None:  # update targets
        if targets_data is None:
            tgs = [class_mapping[el] for el in dataset.targets]  # type: ignore
        else:
            tgs = [class_mapping[el] for el in targets_data]

        targets_data = DataAttribute(tgs, "targets")
        frozen_transform_groups = DefaultTransformGroups(
            (None, partial(lookup, class_mapping))
        )
    else:
        frozen_transform_groups = None

    das = []
    if targets_data is not None:
        das.append(targets_data)

    if task_labels_data is not None:
        # Special treatment of task labels depending on their length, for
        # backward compatibility.
        if len(task_labels_data) != len(dataset):
            # The task labels are already subsampled.
            dataset_avl = AvalancheDataset(
                [dataset],
                indices=list(indices) if indices is not None else None,
                data_attributes=das,
                transform_groups=transform_gs,
                frozen_transform_groups=frozen_transform_groups,
                collate_fn=collate_fn,
            )

            # Now add the task labels.
            return ClassificationDataset(
                [dataset_avl],
                data_attributes=[task_labels_data])
        else:
            das.append(task_labels_data)

    return ClassificationDataset(
        [dataset],
        indices=list(indices) if indices is not None else None,
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        frozen_transform_groups=frozen_transform_groups,
        collate_fn=collate_fn,
    )

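# A hypothetical sketch of ``classification_subset``: keep the first three
# samples and remap class IDs via new_id = class_mapping[old_id]. It relies
# on ``make_tensor_classification_dataset``, defined below.
def _example_classification_subset():
    x = torch.randn(6, 4)
    y = torch.tensor([0, 1, 2, 0, 1, 2])
    data = make_tensor_classification_dataset(x, y, task_labels=0)

    # Remap classes 0 -> 10, 1 -> 11, 2 -> 12 on a 3-element subset.
    remapped = classification_subset(
        data, indices=[0, 1, 2], class_mapping=[10, 11, 12]
    )
    assert list(remapped.targets) == [10, 11, 12]

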
def make_tensor_classification_dataset(
    *dataset_tensors: Sequence,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Dict[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = "train",
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Union[Sequence[TTargetType], int]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> ClassificationDataset:
    """Creates an ``AvalancheTensorDataset`` instance.

    A Dataset that wraps existing ndarrays, Tensors, lists... to provide
    basic Dataset functionalities. Very similar to TensorDataset from PyTorch,
    this Dataset also supports transformations, slicing, advanced indexing,
    the targets field and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset_tensors: Sequences, Tensors or ndarrays representing the
        content of the dataset.
    :param transform: A function/transform that takes in a single element
        from the first tensor and returns a transformed version.
    :param target_transform: A function/transform that takes a single
        element of the second tensor and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the transform group
        to be used. Defaults to 'train'.
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively, can be a
        single int value, in which case that value will be used as the task
        label for all the instances. Defaults to None, which means that a
        default task label 0 will be applied to all patterns.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the second tensor of
        the dataset. Otherwise, it can be a sequence of values containing
        as many elements as the number of patterns, or an int, in which
        case it is interpreted as the index of the tensor carrying the
        targets.
    :param collate_fn: The function to use when slicing to merge single
        patterns. In the future this function may become the function
        used in the data loading process, too.
    """
    if len(dataset_tensors) < 1:
        raise ValueError("At least one sequence must be passed")

    if targets is None:
        targets = dataset_tensors[1]
    elif isinstance(targets, int):
        targets = dataset_tensors[targets]
    tts = []
    for tt in dataset_tensors:  # TensorDataset requires PyTorch tensors
        if not hasattr(tt, 'size'):
            tt = torch.tensor(tt)
        tts.append(tt)
    dataset = _TensorClassificationDataset(*tts)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data = _init_targets(dataset, targets)
    task_labels_data = _init_task_labels(dataset, task_labels)
    if initial_transform_group is not None and isinstance(
        dataset, AvalancheDataset
    ):
        dataset = dataset.with_transforms(initial_transform_group)

    das = []
    for d in [targets_data, task_labels_data]:
        if d is not None:
            das.append(d)

    return ClassificationDataset(
        [dataset],
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        collate_fn=collate_fn,
    )

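# A minimal, hypothetical sketch of ``make_tensor_classification_dataset``:
# the first tensor is the input X, the second carries the targets.
def _example_make_tensor_classification_dataset():
    x = torch.randn(10, 3, 32, 32)
    y = torch.randint(0, 5, (10,))
    data = make_tensor_classification_dataset(x, y, task_labels=0)

    assert len(data) == 10
    # Targets are exposed as plain Python ints, not 0-dim tensors.
    assert list(data.targets) == y.tolist()

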
class _TensorClassificationDataset(TensorDataset):
    """We want class labels to be integers, not tensors."""

    def __getitem__(self, item):
        elem = list(super().__getitem__(item))
        elem[1] = elem[1].item()
        return tuple(elem)

def concat_classification_datasets(
    datasets: Sequence[SupportedDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int,
                                Sequence[int],
                                Sequence[Sequence[int]]]] = None,
    targets: Optional[Union[
        Sequence[TTargetType], Sequence[Sequence[TTargetType]]
    ]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> ClassificationDataset:
    """Creates an ``AvalancheConcatDataset`` instance.

    For simple concatenation operations you should use the method
    `dataset.concat(other)` or
    `concat_datasets` from `avalanche.benchmarks.utils.utils`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch
    :class:`torch.utils.data.ConcatDataset`. However, this Dataset also
    supports transformations, slicing, advanced indexing and the targets
    field and all the other goodies listed in :class:`AvalancheDataset`.

    This dataset guarantees that the operations involving the transformations
    and transformation groups are consistent across the concatenated dataset
    (if they are subclasses of :class:`AvalancheDataset`).

    :param datasets: A collection of datasets.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that if all
        AvalancheDatasets in the input datasets list agree on a common
        group (the "current group" is the same for all datasets), then that
        group will be used as the initial one. If the list of input datasets
        does not contain an AvalancheDataset or if the AvalancheDatasets
        do not agree on a common group, then 'train' will be used.
    :param targets: The label of each pattern. Can either be a sequence of
        labels or, alternatively, a sequence containing sequences of labels
        (one for each dataset to be concatenated). Defaults to None, which
        means that the targets will be retrieved from the datasets (if
        possible).
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively, task
        labels can be expressed as a sequence containing sequences of ints
        (one for each dataset to be concatenated) or even a single int,
        in which case that value will be used as the task label for all
        instances. Defaults to None, which means that the dataset will try
        to obtain the task labels from the original datasets. If no task
        labels could be found for a dataset, a default task label 0 will
        be applied to all patterns of that dataset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. In the future this function may become the function
        used in the data loading process, too. If None, the constructor
        will check if a `collate_fn` field exists in the first dataset. If
        no such field exists, the default collate function will be used.
        Beware that the chosen collate function will be applied to all
        the concatenated datasets, even if a different collate is defined
        in different datasets.
    """
    dds = []
    per_dataset_task_labels = _split_user_def_task_label(
        datasets,
        task_labels
    )

    per_dataset_targets = _split_user_def_targets(
        datasets,
        targets,
        lambda x: isinstance(x, int)
    )

    # Find a common "current_group" or use "train"
    if initial_transform_group is None:
        initial_transform_group = \
            find_common_transforms_group(datasets, default_group="train")

    for dd, dataset_task_labels, dataset_targets in \
            zip(datasets, per_dataset_task_labels, per_dataset_targets):
        dd = make_classification_dataset(
            dd,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=dataset_task_labels,
            targets=dataset_targets,
            collate_fn=collate_fn,
        )

        dds.append(dd)

    if len(dds) > 0:
        transform_groups_obj = _init_transform_groups(
            transform_groups,
            transform,
            target_transform,
            initial_transform_group,
            dds[0],
        )
    else:
        transform_groups_obj = None

    data: ClassificationDataset = ClassificationDataset(
        dds,
        transform_groups=transform_groups_obj
    )
    return data.with_transforms(initial_transform_group)

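# A hypothetical sketch of ``concat_classification_datasets``: concatenate
# two datasets, assigning a different task label to each one by passing one
# sequence of task labels per dataset.
def _example_concat_classification_datasets():
    d1 = make_tensor_classification_dataset(
        torch.randn(4, 2), torch.tensor([0, 1, 0, 1]))
    d2 = make_tensor_classification_dataset(
        torch.randn(4, 2), torch.tensor([0, 1, 0, 1]))

    merged = concat_classification_datasets(
        [d1, d2],
        task_labels=[[0] * 4, [1] * 4],  # one sequence per dataset
    )
    assert len(merged) == 8
    assert set(merged.targets_task_labels) == {0, 1}

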
def _select_targets(
        dataset: SupportedDataset,
        indices: Optional[List[int]]) -> Sequence[TTargetType]:
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets  # type: ignore
    elif hasattr(dataset, "tensors"):
        # Support for PyTorch TensorDataset
        if len(dataset.tensors) < 2:  # type: ignore
            raise ValueError(
                "Tensor dataset does not have enough tensors: "
                "at least 2 are required."
            )
        found_targets = dataset.tensors[1]  # type: ignore
    else:
        raise ValueError(
            "Unsupported dataset: must have a valid targets field "
            "or be a tensor dataset with at least 2 tensors"
        )

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets

def concat_classification_datasets_sequentially(
    train_dataset_list: Sequence[ISupportedClassificationDataset],
    test_dataset_list: Sequence[ISupportedClassificationDataset],
) -> Tuple[ClassificationDataset,
           ClassificationDataset,
           List[list]]:
    """
    Concatenates a list of datasets. This is completely different from
    :class:`ConcatDataset`, in which datasets are merged together without
    other processing. Instead, this function re-maps the datasets' class IDs.
    For instance:
    let dataset[0] contain patterns of 3 different classes and
    let dataset[1] contain patterns of 2 different classes; then class IDs
    will be mapped as follows:

    dataset[0] class "0" -> new class ID is "0"

    dataset[0] class "1" -> new class ID is "1"

    dataset[0] class "2" -> new class ID is "2"

    dataset[1] class "0" -> new class ID is "3"

    dataset[1] class "1" -> new class ID is "4"

    ... -> ...

    dataset[-1] class "C-1" -> new class ID is "overall_n_classes-1"

    In contrast, using PyTorch ConcatDataset:

    dataset[0] class "0" -> ID is "0"

    dataset[0] class "1" -> ID is "1"

    dataset[0] class "2" -> ID is "2"

    dataset[1] class "0" -> ID is "0"

    dataset[1] class "1" -> ID is "1"

    Note: ``train_dataset_list`` and ``test_dataset_list`` must have the same
    number of datasets.

    :param train_dataset_list: A list of training datasets
    :param test_dataset_list: A list of test datasets

    :returns: A tuple containing the concatenated training dataset, the
        concatenated test dataset, and the list of new class IDs assigned
        to each dataset.
    """
    remapped_train_datasets: List[ClassificationDataset] = []
    remapped_test_datasets: List[ClassificationDataset] = []
    next_remapped_idx = 0

    train_dataset_list_sup = list(
        map(make_classification_dataset, train_dataset_list)
    )
    test_dataset_list_sup = list(
        map(make_classification_dataset, test_dataset_list)
    )
    del train_dataset_list
    del test_dataset_list

    # Obtain the number of classes of each dataset
    classes_per_dataset = [
        _count_unique(
            train_dataset_list_sup[dataset_idx].targets,
            test_dataset_list_sup[dataset_idx].targets,
        )
        for dataset_idx in range(len(train_dataset_list_sup))
    ]

    new_class_ids_per_dataset = []
    for dataset_idx in range(len(train_dataset_list_sup)):

        # Get the train and test sets of the dataset
        train_set = train_dataset_list_sup[dataset_idx]
        test_set = test_dataset_list_sup[dataset_idx]

        # Get the classes in the dataset
        dataset_classes = set(map(int, train_set.targets))

        # The class IDs for this dataset will be in range
        # [n_classes_in_previous_datasets,
        #       n_classes_in_previous_datasets + classes_in_this_dataset)
        new_classes = list(
            range(
                next_remapped_idx,
                next_remapped_idx + classes_per_dataset[dataset_idx],
            )
        )
        new_class_ids_per_dataset.append(new_classes)

        # AvalancheSubset is used to apply the class IDs transformation.
        # Remember, the class_mapping parameter must be a list in which:
        # new_class_id = class_mapping[original_class_id]
        # Hence, a list of size equal to the maximum class index plus one is
        # created. Only elements corresponding to the present classes are
        # remapped.
        class_mapping = [-1] * (max(dataset_classes) + 1)
        for j, i in enumerate(sorted(dataset_classes)):
            class_mapping[i] = new_classes[j]

        # Create the remapped datasets and append them to the final list
        remapped_train_datasets.append(
            classification_subset(train_set, class_mapping=class_mapping)
        )
        remapped_test_datasets.append(
            classification_subset(test_set, class_mapping=class_mapping)
        )
        next_remapped_idx += classes_per_dataset[dataset_idx]

    return (
        concat_classification_datasets(remapped_train_datasets),
        concat_classification_datasets(remapped_test_datasets),
        new_class_ids_per_dataset,
    )

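# A hypothetical sketch of the sequential concatenation above: the classes
# of the second dataset are shifted past those of the first, so the two
# label spaces do not overlap.
def _example_concat_sequentially():
    def toy(n_classes):
        x = torch.randn(n_classes * 2, 2)
        y = torch.arange(n_classes).repeat(2)
        return make_tensor_classification_dataset(x, y)

    train_a, test_a = toy(3), toy(3)  # classes 0..2 -> stay 0..2
    train_b, test_b = toy(2), toy(2)  # classes 0..1 -> become 3..4

    train, test, class_ids = concat_classification_datasets_sequentially(
        [train_a, train_b], [test_a, test_b]
    )
    assert class_ids == [[0, 1, 2], [3, 4]]
    assert set(map(int, train.targets)) == {0, 1, 2, 3, 4}

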
__all__ = [
    "SupportedDataset",
    "ClassificationDataset",
    "make_classification_dataset",
    "classification_subset",
    "make_tensor_classification_dataset",
    "concat_classification_datasets",
    "concat_classification_datasets_sequentially"
]