################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini, Antonio Carta                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the ``ClassificationDataset``,
which is the dataset used for supervised continual learning benchmarks.
ClassificationDatasets are ``AvalancheDatasets`` that manage class and task
labels automatically. Concatenation and subsampling operations are optimized
to be used frequently, as is common in replay strategies.
"""

from functools import partial
import torch
from torch.utils.data.dataset import Subset, ConcatDataset, TensorDataset

from avalanche.benchmarks.utils.utils import (
    TaskSet,
    _count_unique,
    find_common_transforms_group,
    _init_task_labels,
    _init_transform_groups,
    _split_user_def_targets,
    _split_user_def_task_label,
    _traverse_supported_dataset,
)

from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.transform_groups import (
    TransformGroupDef,
    DefaultTransformGroups,
    XTransform,
    YTransform,
)
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.benchmarks.utils.dataset_utils import (
    SubSequence,
)
from avalanche.benchmarks.utils.flat_data import ConstantSequence
from avalanche.benchmarks.utils.dataset_definitions import (
    ISupportedClassificationDataset,
    ITensorDataset,
    IDatasetWithTargets,
)

from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
    overload,
)


T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TTargetType = int

TClassificationDataset = TypeVar(
    "TClassificationDataset", bound="ClassificationDataset"
)


def lookup(indexable, idx):
    """
    A simple function that implements indexing into an indexable object.
    Together with 'partial' this allows us to circumvent lambda functions
    that cannot be pickled.
    """
    return indexable[idx]


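# A minimal sketch (added for illustration only; `_lookup_example` is not part
# of the Avalanche API) of why `lookup` exists: `functools.partial(lookup, ...)`
# can be pickled, while the equivalent `lambda idx: mapping[idx]` cannot.
def _lookup_example():
    mapping = [10, 11, 12]
    remap = partial(lookup, mapping)  # picklable replacement for the lambda
    assert remap(2) == 12

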
class ClassificationDataset(AvalancheDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert "targets" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets field"
        )

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes["targets"]

    # TODO: this shouldn't be needed
    def subset(self, indices):
        data = super().subset(indices)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    # TODO: this shouldn't be needed
    def concat(self, other):
        data = super().concat(other)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def __hash__(self):
        return id(self)


class TaskAwareClassificationDataset(AvalancheDataset[T_co]):
    @property
    def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
        """A dictionary mapping task ids to their sample indices."""
        return self.targets_task_labels.val_to_idx  # type: ignore

    @property
    def task_set(self: TClassificationDataset) -> TaskSet[TClassificationDataset]:
        """Returns the dataset's ``TaskSet``, which is a mapping <task-id,
        task-dataset>."""
        return TaskSet(self)

    def subset(self, indices):
        data = super().subset(indices)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def concat(self, other):
        data = super().concat(other)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def __hash__(self):
        return id(self)


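# A minimal sketch of the ``task_set`` mapping described above, assuming it
# behaves as a standard mapping <task-id, task-dataset>. `_task_set_example`
# is a hypothetical helper added for illustration only; it returns the number
# of patterns belonging to each task.
def _task_set_example(data: TaskAwareClassificationDataset) -> Dict[int, int]:
    # Iterating the mapping yields its keys, i.e. the task labels.
    return {task_id: len(data.task_set[task_id]) for task_id in data.task_set}

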
class TaskAwareSupervisedClassificationDataset(TaskAwareClassificationDataset[T_co]):
    # TODO: remove? ClassificationDataset should have targets
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert "targets" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets field"
        )
        assert "targets_task_labels" in self._data_attributes, (
            "The supervised version of the ClassificationDataset requires "
            + "the targets_task_labels field"
        )

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes["targets"]

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes["targets_task_labels"]


SupportedDataset = Union[
    IDatasetWithTargets,
    ITensorDataset,
    Subset,
    ConcatDataset,
    TaskAwareClassificationDataset,
]


@overload
def _make_taskaware_classification_dataset(
    dataset: TaskAwareSupervisedClassificationDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _make_taskaware_classification_dataset(
    dataset: SupportedDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int]],
    targets: Sequence[TTargetType],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _make_taskaware_classification_dataset(
    dataset: SupportedDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareClassificationDataset: ...


def _make_taskaware_classification_dataset(
    dataset: SupportedDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
    """Avalanche Classification Dataset.

    Supervised continual learning benchmarks in Avalanche return instances of
    this dataset, but it can also be used in a completely standalone manner.

    This dataset applies input/target transformations, it supports
    slicing and advanced indexing, and it contains useful fields such as
    `targets`, which holds the pattern labels, and `targets_task_labels`,
    which holds the pattern task labels. The `task_set` field can be used to
    obtain the subset of patterns labeled with a given task label.

    This dataset can also be used to apply several advanced operations involving
    transformations. For instance, it allows the user to add and replace
    transformations, freeze them so that they can't be changed, etc.

    This dataset also allows the user to keep distinct transformation groups.
    Simply put, a transformation group is a pair of transform+target_transform
    (exactly as in torchvision datasets). This dataset natively supports keeping
    two transformation groups: the first, 'train', contains transformations
    applied to training patterns. Those transformations usually involve some
    kind of data augmentation. The second, 'eval', contains transformations
    applied to test patterns. Having both groups can be useful when, for
    instance, you need to test on the training data (as this process usually
    involves removing data augmentation operations). Switching between
    transformations can be easily achieved by using the
    :func:`train` and :func:`eval` methods.

    Moreover, arbitrary transformation groups can be added and used. For more
    info see the constructor and the :func:`with_transforms` method.

    This dataset will try to inherit the task labels from the input
    dataset. If none are available and none are given via the `task_labels`
    parameter, each pattern will be assigned a default task label 0.

    Creates an ``AvalancheDataset`` instance.

    :param dataset: The dataset to decorate. Beware that
        AvalancheDataset will not overwrite transformations already
        applied by this dataset.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when you need to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label of each instance. Must be a sequence
        of ints, one for each instance in the dataset. Alternatively can be
        a single int value, in which case that value will be used as the
        task label for all the instances. Defaults to None, which means that
        the dataset will try to obtain the task labels from the original
        dataset. If no task labels could be found, a default task label
        0 will be applied to all instances.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the dataset (if
        possible).
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is also used in the data loading process.
        If None, the constructor will check if a `collate_fn` field exists
        in the dataset. If no such field exists, the default collate
        function will be used.
    """

    is_supervised = isinstance(dataset, TaskAwareSupervisedClassificationDataset)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data: Optional[DataAttribute[TTargetType]] = _init_targets(dataset, targets)
    task_labels_data: Optional[DataAttribute[int]] = _init_task_labels(
        dataset, task_labels
    )

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    # Check if supervision data has been added
    is_supervised = is_supervised or (
        targets_data is not None and task_labels_data is not None
    )

    data: Union[
        TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset
    ]
    if is_supervised:
        data = TaskAwareSupervisedClassificationDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )
    else:
        data = TaskAwareClassificationDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data


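# A minimal usage sketch, assuming that a two-tensor ``TensorDataset`` is a
# valid ``SupportedDataset`` and that an int `task_labels` is broadcast to all
# patterns, as documented above. `_make_dataset_example` is a hypothetical
# helper added for illustration only and is never called at import time.
def _make_dataset_example():
    xs = torch.randn(10, 3)
    ys = torch.arange(10) % 2  # the second tensor is picked up as `targets`
    data = _make_taskaware_classification_dataset(
        TensorDataset(xs, ys), task_labels=0
    )
    # Switch between the 'train' and 'eval' transform groups.
    return data.train(), data.eval()

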
def _init_targets(
    dataset, targets, check_shape=True
) -> Optional[DataAttribute[TTargetType]]:
    if targets is not None:
        # User defined targets always take precedence
        if isinstance(targets, int):
            targets = ConstantSequence(targets, len(dataset))
        elif len(targets) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid amount of target labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(targets), len(dataset))
            )
        return DataAttribute(targets, "targets")

    targets = _traverse_supported_dataset(dataset, _select_targets)

    if targets is not None:
        if isinstance(targets, torch.Tensor):
            targets = targets.tolist()

    if targets is None:
        return None

    return DataAttribute(targets, "targets")


@overload
def _taskaware_classification_subset(
    dataset: TaskAwareSupervisedClassificationDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _taskaware_classification_subset(
    dataset: SupportedDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int]],
    targets: Sequence[TTargetType],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _taskaware_classification_subset(
    dataset: SupportedDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareClassificationDataset: ...


def _taskaware_classification_subset(
    dataset: SupportedDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
    """Creates an ``AvalancheSubset`` instance.

    For simple subset operations you should use the method
    `dataset.subset(indices)`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`.
    This Dataset also supports transformations, slicing, advanced indexing,
    the targets field, class mapping and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset: The whole dataset.
    :param indices: Indices in the whole set selected for subset. Can
        be None, which means that the whole dataset will be returned.
    :param class_mapping: A list that, for each possible target (Y) value,
        contains its corresponding remapped value. Can be None.
        Beware that setting this parameter will force the final
        dataset type to be CLASSIFICATION or UNDEFINED.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when you need to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label for each instance. Must be a sequence
        of ints, one for each instance in the dataset. This can either be a
        list of task labels for the original dataset or the list of task
        labels for the instances of the subset (an automatic detection will
        be made). In the unfortunate case in which the original dataset and
        the subset contain the same number of instances, this parameter
        is considered to contain the task labels of the subset.
        Alternatively can be a single int value, in which case
        that value will be used as the task label for all the instances.
        Defaults to None, which means that the dataset will try to
        obtain the task labels from the original dataset. If no task labels
        could be found, a default task label 0 will be applied to all
        instances.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the dataset (if
        possible). This can either be a list of target labels for the
        original dataset or the list of target labels for the instances of
        the subset (an automatic detection will be made). In the unfortunate
        case in which the original dataset and the subset contain the same
        number of instances, this parameter is considered to contain
        the target labels of the subset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is also used in the data loading process.
        If None, the constructor will check if a `collate_fn` field exists
        in the dataset. If no such field exists, the default collate
        function will be used.
    """

    is_supervised = isinstance(dataset, TaskAwareSupervisedClassificationDataset)

    if isinstance(dataset, TaskAwareClassificationDataset):
        if (
            class_mapping is None
            and transform is None
            and target_transform is None
            and transform_groups is None
            and initial_transform_group is None
            and task_labels is None
            and targets is None
            and collate_fn is None
        ):
            return dataset.subset(indices)

    targets_data: Optional[DataAttribute[TTargetType]] = _init_targets(
        dataset, targets, check_shape=False
    )
    task_labels_data: Optional[DataAttribute[int]] = _init_task_labels(
        dataset, task_labels, check_shape=False
    )

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )

    if initial_transform_group is not None and isinstance(dataset, AvalancheDataset):
        dataset = dataset.with_transforms(initial_transform_group)

    if class_mapping is not None:  # update targets
        if targets_data is None:
            tgs = [class_mapping[el] for el in dataset.targets]  # type: ignore
        else:
            tgs = [class_mapping[el] for el in targets_data]

        targets_data = DataAttribute(tgs, "targets")

    if class_mapping is not None:
        frozen_transform_groups = DefaultTransformGroups(
            (None, partial(lookup, class_mapping))
        )
    else:
        frozen_transform_groups = None

    das = []
    if targets_data is not None:
        das.append(targets_data)

    # Check if supervision data has been added
    is_supervised = is_supervised or (
        targets_data is not None and task_labels_data is not None
    )

    if task_labels_data is not None:
        # special treatment for task labels depending on length for
        # backward compatibility
        if len(task_labels_data) != len(dataset):
            # task labels are already subsampled
            dataset = TaskAwareClassificationDataset(
                [dataset],
                indices=list(indices) if indices is not None else None,
                data_attributes=das,
                transform_groups=transform_gs,
                frozen_transform_groups=frozen_transform_groups,
                collate_fn=collate_fn,
            )
            # now add task labels
            if is_supervised:
                return TaskAwareSupervisedClassificationDataset(
                    [dataset],
                    data_attributes=[dataset.targets, task_labels_data],  # type: ignore
                )
            else:
                return TaskAwareClassificationDataset(
                    [dataset],
                    data_attributes=[dataset.targets, task_labels_data],  # type: ignore
                )
        else:
            das.append(task_labels_data)

    if is_supervised:
        return TaskAwareSupervisedClassificationDataset(
            [dataset],
            indices=list(indices) if indices is not None else None,
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn,
        )
    else:
        return TaskAwareClassificationDataset(
            [dataset],
            indices=list(indices) if indices is not None else None,
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn,
        )


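# A short, plain-Python sketch (added for illustration only) of the
# `class_mapping` semantics documented above:
# new_class_id = class_mapping[original_class_id].
def _class_mapping_example():
    class_mapping = [2, 0, 1]  # 0 -> 2, 1 -> 0, 2 -> 1
    original_targets = [0, 1, 2, 1]
    remapped = [class_mapping[t] for t in original_targets]
    assert remapped == [2, 0, 1, 0]

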
@overload
def _make_taskaware_tensor_classification_dataset(
    *dataset_tensors: Sequence,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Dict[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = "train",
    task_labels: Union[int, Sequence[int]],
    targets: Union[Sequence[TTargetType], int],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _make_taskaware_tensor_classification_dataset(
    *dataset_tensors: Sequence,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Dict[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = "train",
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Union[Sequence[TTargetType], int]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[
    TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset
]: ...


def _make_taskaware_tensor_classification_dataset(
    *dataset_tensors: Sequence,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Dict[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = "train",
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Union[Sequence[TTargetType], int]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
    """Creates an ``AvalancheTensorDataset`` instance.

    A Dataset that wraps existing ndarrays, Tensors, lists... to provide
    basic Dataset functionalities. Very similar to TensorDataset from PyTorch,
    this Dataset also supports transformations, slicing, advanced indexing,
    the targets field and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset_tensors: Sequences, Tensors or ndarrays representing the
        content of the dataset.
    :param transform: A function/transform that takes in a single element
        from the first tensor and returns a transformed version.
    :param target_transform: A function/transform that takes a single
        element of the second tensor and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when you need to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the transform group
        to be used. Defaults to 'train'.
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively can be a
        single int value, in which case that value will be used as the task
        label for all the instances. Defaults to None, which means that a
        default task label 0 will be applied to all patterns.
    :param targets: The label of each pattern. Defaults to None, which
        means that the targets will be retrieved from the second tensor of
        the dataset. Otherwise, it can be a sequence of values containing
        as many elements as the number of patterns, or a single int
        indicating which of the input tensors holds the targets.
    :param collate_fn: The function to use when slicing to merge single
        patterns. In the future this function may become the function
        used in the data loading process, too.
    """
    if len(dataset_tensors) < 1:
        raise ValueError("At least one sequence must be passed")

    if targets is None:
        targets = dataset_tensors[1]
    elif isinstance(targets, int):
        targets = dataset_tensors[targets]
    tts = []
    for tt in dataset_tensors:  # TensorDataset requires PyTorch tensors
        if not hasattr(tt, "size"):
            tt = torch.tensor(tt)
        tts.append(tt)
    dataset = _TensorClassificationDataset(*tts)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data = _init_targets(dataset, targets)
    task_labels_data = _init_task_labels(dataset, task_labels)
    if initial_transform_group is not None and isinstance(dataset, AvalancheDataset):
        dataset = dataset.with_transforms(initial_transform_group)

    das = []
    for d in [targets_data, task_labels_data]:
        if d is not None:
            das.append(d)

    # Check if supervision data has been added
    is_supervised = targets_data is not None and task_labels_data is not None

    if is_supervised:
        return TaskAwareSupervisedClassificationDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )
    else:
        return TaskAwareClassificationDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )


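# A minimal sketch, assuming the second positional tensor is used as the
# targets when `targets` is None, as documented above.
# `_tensor_dataset_example` is a hypothetical helper added for illustration
# only. Class labels come back as plain ints thanks to
# _TensorClassificationDataset (defined below).
def _tensor_dataset_example():
    xs = torch.zeros(4, 2)
    ys = torch.tensor([0, 1, 0, 1])
    return _make_taskaware_tensor_classification_dataset(xs, ys, task_labels=0)

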
class _TensorClassificationDataset(TensorDataset):
    """We want class labels to be integers, not tensors."""

    def __getitem__(self, item):
        elem = list(super().__getitem__(item))
        elem[1] = elem[1].item()
        return tuple(elem)


@overload
def _concat_taskaware_classification_datasets(
    datasets: Sequence[TaskAwareSupervisedClassificationDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _concat_taskaware_classification_datasets(
    datasets: Sequence[SupportedDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]],
    targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset: ...


@overload
def _concat_taskaware_classification_datasets(
    datasets: Sequence[SupportedDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareClassificationDataset: ...


def _concat_taskaware_classification_datasets(
    datasets: Sequence[SupportedDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
    """Creates an ``AvalancheConcatDataset`` instance.

    For simple concatenation operations you should use the method
    `dataset.concat(other)` or
    `concat_datasets` from `avalanche.benchmarks.utils.utils`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch
    :class:`torch.utils.data.ConcatDataset`. However, this Dataset also supports
    transformations, slicing, advanced indexing, the targets field and all
    the other goodies listed in :class:`AvalancheDataset`.

    This dataset guarantees that the operations involving the transformations
    and transformation groups are consistent across the concatenated datasets
    (if they are subclasses of :class:`AvalancheDataset`).

    :param datasets: A collection of datasets.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when you need to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that if all
        AvalancheDatasets in the input datasets list agree on a common
        group (the "current group" is the same for all datasets), then that
        group will be used as the initial one. If the list of input datasets
        does not contain an AvalancheDataset or if the AvalancheDatasets
        do not agree on a common group, then 'train' will be used.
    :param targets: The label of each pattern. Can either be a sequence of
        labels or, alternatively, a sequence containing sequences of labels
        (one for each dataset to be concatenated). Defaults to None, which
        means that the targets will be retrieved from the datasets (if
        possible).
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively, task
        labels can be expressed as a sequence containing sequences of ints
        (one for each dataset to be concatenated) or even a single int,
        in which case that value will be used as the task label for all
        instances. Defaults to None, which means that the dataset will try
        to obtain the task labels from the original datasets. If no task
        labels could be found for a dataset, a default task label 0 will
        be applied to all patterns of that dataset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. In the future this function may become the function
        used in the data loading process, too. If None, the constructor
        will check if a `collate_fn` field exists in the first dataset. If
        no such field exists, the default collate function will be used.
        Beware that the chosen collate function will be applied to all
        the concatenated datasets even if a different collate is defined
        in different datasets.
    """
    dds = []
    per_dataset_task_labels = _split_user_def_task_label(datasets, task_labels)

    per_dataset_targets = _split_user_def_targets(
        datasets, targets, lambda x: isinstance(x, int)
    )

    # Find common "current_group" or use "train"
    if initial_transform_group is None:
        initial_transform_group = find_common_transforms_group(
            datasets, default_group="train"
        )

    supervised = True
    for dd, dataset_task_labels, dataset_targets in zip(
        datasets, per_dataset_task_labels, per_dataset_targets
    ):
        dd = _make_taskaware_classification_dataset(
            dd,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=dataset_task_labels,
            targets=dataset_targets,
            collate_fn=collate_fn,
        )

        if not isinstance(dd, TaskAwareSupervisedClassificationDataset):
            supervised = False

        dds.append(dd)

    if len(dds) > 0:
        transform_groups_obj = _init_transform_groups(
            transform_groups,
            transform,
            target_transform,
            initial_transform_group,
            dds[0],
        )
    else:
        transform_groups_obj = None

    supervised = supervised and (
        (len(dds) > 0) or (targets is not None and task_labels is not None)
    )

    data: Union[
        TaskAwareSupervisedClassificationDataset, TaskAwareClassificationDataset
    ]
    if supervised:
        data = TaskAwareSupervisedClassificationDataset(
            dds, transform_groups=transform_groups_obj
        )
    else:
        data = TaskAwareClassificationDataset(
            dds, transform_groups=transform_groups_obj
        )
    return data.with_transforms(initial_transform_group)


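# A minimal concatenation sketch (added for illustration only; `_concat_example`
# is a hypothetical helper), assuming int task labels are broadcast to all
# patterns as documented above.
def _concat_example():
    d1 = _make_taskaware_tensor_classification_dataset(
        torch.zeros(3, 2), torch.tensor([0, 1, 0]), task_labels=0
    )
    d2 = _make_taskaware_tensor_classification_dataset(
        torch.ones(3, 2), torch.tensor([1, 0, 1]), task_labels=1
    )
    # The result keeps the per-pattern class and task labels of both datasets.
    return _concat_taskaware_classification_datasets([d1, d2])

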
def _select_targets(
    dataset: SupportedDataset, indices: Optional[List[int]]
) -> Sequence[TTargetType]:
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets  # type: ignore
    elif hasattr(dataset, "tensors"):
        # Support for PyTorch TensorDataset
        if len(dataset.tensors) < 2:  # type: ignore
            raise ValueError(
                "Tensor dataset does not have enough tensors: "
                "at least 2 are required."
            )
        found_targets = dataset.tensors[1]  # type: ignore
    else:
        raise ValueError(
            "Unsupported dataset: it must have a valid targets field "
            "or be a TensorDataset with at least 2 tensors"
        )

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets


def _concat_taskaware_classification_datasets_sequentially(
    train_dataset_list: Sequence[ISupportedClassificationDataset],
    test_dataset_list: Sequence[ISupportedClassificationDataset],
) -> Tuple[
    TaskAwareSupervisedClassificationDataset,
    TaskAwareSupervisedClassificationDataset,
    List[list],
]:
    """
    Concatenates a list of datasets. This is completely different from
    :class:`ConcatDataset`, in which datasets are merged together without
    other processing. Instead, this function re-maps the datasets' class IDs.
    For instance:
    let dataset[0] contain patterns of 3 different classes,
    let dataset[1] contain patterns of 2 different classes, then class IDs
    will be mapped as follows:

    dataset[0] class "0" -> new class ID is "0"

    dataset[0] class "1" -> new class ID is "1"

    dataset[0] class "2" -> new class ID is "2"

    dataset[1] class "0" -> new class ID is "3"

    dataset[1] class "1" -> new class ID is "4"

    ... -> ...

    dataset[-1] class "C-1" -> new class ID is "overall_n_classes-1"

    In contrast, using PyTorch ConcatDataset:

    dataset[0] class "0" -> ID is "0"

    dataset[0] class "1" -> ID is "1"

    dataset[0] class "2" -> ID is "2"

    dataset[1] class "0" -> ID is "0"

    dataset[1] class "1" -> ID is "1"

    Note: ``train_dataset_list`` and ``test_dataset_list`` must have the same
    number of datasets.

    :param train_dataset_list: A list of training datasets
    :param test_dataset_list: A list of test datasets

    :returns: A tuple containing the concatenated training dataset, the
        concatenated test dataset and the list of new class IDs assigned
        to each dataset.
    """
    remapped_train_datasets: List[TaskAwareSupervisedClassificationDataset] = []
    remapped_test_datasets: List[TaskAwareSupervisedClassificationDataset] = []
    next_remapped_idx = 0

    train_dataset_list_sup = list(
        map(_as_taskaware_supervised_classification_dataset, train_dataset_list)
    )
    test_dataset_list_sup = list(
        map(_as_taskaware_supervised_classification_dataset, test_dataset_list)
    )
    del train_dataset_list
    del test_dataset_list

    # Obtain the number of classes of each dataset
    classes_per_dataset = [
        _count_unique(
            train_dataset_list_sup[dataset_idx].targets,
            test_dataset_list_sup[dataset_idx].targets,
        )
        for dataset_idx in range(len(train_dataset_list_sup))
    ]

    new_class_ids_per_dataset = []
    for dataset_idx in range(len(train_dataset_list_sup)):
        # Get the train and test sets of the dataset
        train_set = train_dataset_list_sup[dataset_idx]
        test_set = test_dataset_list_sup[dataset_idx]

        # Get the classes in the dataset
        dataset_classes = set(map(int, train_set.targets))

        # The class IDs for this dataset will be in range
        # [n_classes_in_previous_datasets,
        #       n_classes_in_previous_datasets + classes_in_this_dataset)
        new_classes = list(
            range(
                next_remapped_idx,
                next_remapped_idx + classes_per_dataset[dataset_idx],
            )
        )
        new_class_ids_per_dataset.append(new_classes)

        # AvalancheSubset is used to apply the class IDs transformation.
        # Remember, the class_mapping parameter must be a list in which:
        # new_class_id = class_mapping[original_class_id]
        # Hence, a list of size equal to the maximum class index + 1 is
        # created. Only elements corresponding to the present classes are
        # remapped.
        class_mapping = [-1] * (max(dataset_classes) + 1)
        j = 0
        for i in dataset_classes:
            class_mapping[i] = new_classes[j]
            j += 1

        # Create remapped datasets and append them to the final list
        remapped_train_datasets.append(
            _taskaware_classification_subset(train_set, class_mapping=class_mapping)
        )
        remapped_test_datasets.append(
            _taskaware_classification_subset(test_set, class_mapping=class_mapping)
        )
        next_remapped_idx += classes_per_dataset[dataset_idx]

    return (
        _concat_taskaware_classification_datasets(remapped_train_datasets),
        _concat_taskaware_classification_datasets(remapped_test_datasets),
        new_class_ids_per_dataset,
    )


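# A short, plain-Python sketch (added for illustration only) of the contiguous
# re-mapping arithmetic described in the docstring above: 3 classes in
# dataset[0] and 2 in dataset[1] yield the IDs [0, 1, 2] and [3, 4].
def _sequential_remapping_example():
    classes_per_dataset = [3, 2]
    new_class_ids: List[List[int]] = []
    next_remapped_idx = 0
    for n_classes in classes_per_dataset:
        new_class_ids.append(
            list(range(next_remapped_idx, next_remapped_idx + n_classes))
        )
        next_remapped_idx += n_classes
    assert new_class_ids == [[0, 1, 2], [3, 4]]

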
def _as_taskaware_supervised_classification_dataset(
    dataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> TaskAwareSupervisedClassificationDataset:
    if (
        transform is not None
        or target_transform is not None
        or transform_groups is not None
        or initial_transform_group is not None
        or task_labels is not None
        or targets is not None
        or collate_fn is not None
        or not isinstance(dataset, TaskAwareSupervisedClassificationDataset)
    ):
        result_dataset = _make_taskaware_classification_dataset(
            dataset=dataset,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=task_labels,
            targets=targets,
            collate_fn=collate_fn,
        )

        if not isinstance(result_dataset, TaskAwareSupervisedClassificationDataset):
            raise ValueError(
                "The given dataset does not have supervision fields "
                "(targets, task_labels)."
            )

        return result_dataset

    return dataset


__all__ = [
    "SupportedDataset",
    "TaskAwareClassificationDataset",
    "TaskAwareSupervisedClassificationDataset",
    "_make_taskaware_classification_dataset",
    "_make_taskaware_tensor_classification_dataset",
    "_taskaware_classification_subset",
    "_concat_taskaware_classification_datasets",
]