ContinualAI / avalanche · build 8382229381

21 Mar 2024 09:54PM UTC · coverage: 51.806% (remained the same)

Pull Request #1629 (WIP: Documentation & Typos) · github / web-flow
Merge 9c9b8f9c8 into dbdc3804b

25 of 47 new or added lines in 6 files covered (53.19%)
71 existing lines in 3 files now uncovered
14756 of 28483 relevant lines covered (51.81%) · 0.52 hits per line

Source file: /avalanche/benchmarks/utils/detection_dataset.py (43.46% of lines covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini, Antonio Carta                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the ``DetectionDataset``,
which is the dataset used for supervised continual learning benchmarks.
DetectionDatasets are ``AvalancheDatasets`` that manage targets and task
labels automatically. Concatenation and subsampling operations are optimized
to be used frequently, as is common in replay strategies.
"""
from functools import partial
from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
    overload,
)

import torch
from torch import Tensor
from torch.utils.data.dataset import Subset, ConcatDataset

from avalanche.benchmarks.utils.utils import (
    TaskSet,
    _init_task_labels,
    _init_transform_groups,
    _split_user_def_targets,
    _split_user_def_task_label,
    _traverse_supported_dataset,
)

from .collate_functions import detection_collate_fn
from .data import AvalancheDataset
from .data_attribute import DataAttribute
from .dataset_definitions import (
    IDataset,
    IDatasetWithTargets,
)
from .dataset_utils import (
    SubSequence,
)
from .flat_data import ConstantSequence
from .transform_groups import (
    TransformGroupDef,
    DefaultTransformGroups,
    TransformGroups,
    XTransform,
    YTransform,
)

T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TTargetType = Dict[str, Tensor]


# Image (tensor), target dict, task label
DetectionExampleT = Tuple[Tensor, TTargetType, int]
TDetectionDataset = TypeVar("TDetectionDataset", bound="DetectionDataset")
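

# Editor's sketch (not part of the original Avalanche source): shows the
# shape of one detection example as typed by ``DetectionExampleT`` above.
# The "boxes"/"labels" keys follow the torchvision detection convention,
# which is an assumption this illustration makes explicit.
def _demo_example_tuple() -> DetectionExampleT:
    image = torch.zeros(3, 224, 224)  # dummy CHW image tensor
    target: TTargetType = {
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0]]),  # one box, xyxy
        "labels": torch.tensor([1]),  # one class label per box
    }
    return image, target, 0  # task label 0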


class DetectionDataset(AvalancheDataset[T_co]):
    @property
    def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
        """A dictionary mapping task ids to their sample indices."""
        return self.targets_task_labels.val_to_idx  # type: ignore

    @property
    def task_set(self: TDetectionDataset) -> TaskSet[TDetectionDataset]:
        """Returns the dataset's ``TaskSet``, which is a mapping <task-id,
        task-dataset>."""
        return TaskSet(self)

    def subset(self, indices):
        data = super().subset(indices)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def concat(self, other):
        data = super().concat(other)
        return data.with_transforms(self._flat_data._transform_groups.current_group)

    def __hash__(self):
        return id(self)
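

# Editor's sketch (not part of the original Avalanche source): ``subset``
# and ``concat`` are overridden above so that derived datasets stay on the
# caller's active transform group instead of resetting it; ``dataset`` here
# stands for any DetectionDataset instance.
def _demo_transform_group_is_kept(dataset: DetectionDataset) -> None:
    eval_view = dataset.eval()             # switch to the 'eval' group
    head = eval_view.subset([0, 1, 2])     # the subset stays on 'eval'
    doubled = eval_view.concat(eval_view)  # the concatenation stays on 'eval'
    per_task = eval_view.task_set          # mapping <task-id, task-dataset>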


class SupervisedDetectionDataset(DetectionDataset[T_co]):
    def __init__(
        self,
        datasets: List[IDataset[T_co]],
        *,
        indices: Optional[List[int]] = None,
        data_attributes: Optional[List[DataAttribute]] = None,
        transform_groups: Optional[TransformGroups] = None,
        frozen_transform_groups: Optional[TransformGroups] = None,
        collate_fn: Optional[Callable[[List], Any]] = None
    ):
        super().__init__(
            datasets=datasets,
            indices=indices,
            data_attributes=data_attributes,
            transform_groups=transform_groups,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn,
        )

        assert hasattr(self, "targets"), (
            "The supervised version of the DetectionDataset requires "
            + "the targets field"
        )
        assert hasattr(self, "targets_task_labels"), (
            "The supervised version of the DetectionDataset requires "
            + "the targets_task_labels field"
        )

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes["targets"]

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes["targets_task_labels"]
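

# Editor's sketch (not part of the original Avalanche source): the
# supervised variant guarantees the two per-pattern attributes checked in
# ``__init__`` above. Index-based access on DataAttribute is assumed here.
def _demo_supervised_fields(dataset: SupervisedDetectionDataset) -> None:
    first_target = dataset.targets[0]             # target dict (boxes, labels)
    first_task = dataset.targets_task_labels[0]   # int task label, e.g. 0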


SupportedDetectionDataset = Union[
    IDatasetWithTargets, Subset, ConcatDataset, DetectionDataset
]


@overload
def make_detection_dataset(
    dataset: SupervisedDetectionDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def make_detection_dataset(
    dataset: SupportedDetectionDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int]],
    targets: Sequence[TTargetType],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def make_detection_dataset(
    dataset: SupportedDetectionDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset: ...


def make_detection_dataset(
    dataset: SupportedDetectionDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[DetectionDataset, SupervisedDetectionDataset]:
    """Avalanche Detection Dataset.

    Supervised continual learning benchmarks in Avalanche return instances of
    this dataset, but it can also be used in a completely standalone manner.

    This dataset applies input/target transformations, supports slicing and
    advanced indexing, and contains useful fields such as `targets`, which
    holds the pattern dictionaries, and `targets_task_labels`, which holds
    the pattern task labels. The `task_set` field can be used to obtain the
    subset of patterns labeled with a given task label.

    This dataset can also be used to apply several advanced operations involving
    transformations. For instance, it allows the user to add and replace
    transformations, freeze them so that they can't be changed, etc.

    This dataset also allows the user to keep distinct transformations groups.
    Simply put, a transformation group is a pair of transform+target_transform
    (exactly as in torchvision datasets). This dataset natively supports keeping
    two transformation groups: the first, 'train', contains transformations
    applied to training patterns. Those transformations usually involve some
    kind of data augmentation. The second one is 'eval', which contains the
    transformations applied to test patterns. Having both groups can be
    useful when, for instance, one needs to test on the training data (as this
    process usually involves removing data augmentation operations). Switching
    between transformations can be easily achieved by using the
    :func:`train` and :func:`eval` methods.

    Moreover, arbitrary transformation groups can be added and used. For more
    info see the constructor and the :func:`with_transforms` method.

    This dataset will try to inherit the task labels from the input
    dataset. If none are available and none are given via the `task_labels`
    parameter, each pattern will be assigned a default task label 0.

    Creates an ``AvalancheDataset`` instance.

    :param dataset: The dataset to decorate. Beware that
        AvalancheDataset will not overwrite transformations already
        applied by this dataset.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label of each instance. Must be a sequence
        of ints, one for each instance in the dataset. Alternatively can be
        a single int value, in which case that value will be used as the
        task label for all the instances. Defaults to None, which means that
        the dataset will try to obtain the task labels from the original
        dataset. If no task labels could be found, a default task label
        0 will be applied to all instances.
    :param targets: The dictionary of detection boxes of each pattern.
        Defaults to None, which means that the targets will be retrieved from
        the dataset (if possible).
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a
        `collate_fn` field exists in the dataset. If no such field exists,
        the default collate function for detection will be used.
    """

    is_supervised = isinstance(dataset, SupervisedDetectionDataset)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data: Optional[DataAttribute[TTargetType]] = _init_targets(dataset, targets)
    task_labels_data: Optional[DataAttribute[int]] = _init_task_labels(
        dataset, task_labels
    )

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    # Check if supervision data has been added
    is_supervised = is_supervised or (
        targets_data is not None and task_labels_data is not None
    )

    if collate_fn is None:
        collate_fn = getattr(dataset, "collate_fn", detection_collate_fn)

    data: Union[DetectionDataset, SupervisedDetectionDataset]
    if is_supervised:
        data = SupervisedDetectionDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )
    else:
        data = DetectionDataset(
            [dataset],
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            collate_fn=collate_fn,
        )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data
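

# Editor's sketch (not part of the original Avalanche source):
# ``torch_dataset`` stands for any torchvision-style detection dataset
# yielding ``(image, target_dict)`` pairs. Every pattern gets task label 0
# and the wrapped dataset gains the 'train'/'eval' transform groups.
def _demo_make_detection_dataset(torch_dataset) -> DetectionDataset:
    data = make_detection_dataset(torch_dataset, task_labels=0)
    return data.train()  # select the 'train' transform group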


def _init_targets(
    dataset, targets, check_shape=True
) -> Optional[DataAttribute[TTargetType]]:
    if targets is not None:
        # User-defined targets always take precedence
        if len(targets) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid number of target labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(targets), len(dataset))
            )
        return DataAttribute(targets, "targets")

    targets = _traverse_supported_dataset(dataset, _select_targets)

    if targets is None:
        return None

    return DataAttribute(targets, "targets")


def _detection_class_mapping_transform(class_mapping, example_target_dict):
    example_target_dict = dict(example_target_dict)

    # example_target_dict["labels"] is a tensor containing one label
    # for each bounding box in the image. We need to remap each of them
    example_target_labels = example_target_dict["labels"]
    example_mapped_labels = [class_mapping[int(el)] for el in example_target_labels]

    if isinstance(example_target_labels, Tensor):
        example_mapped_labels = torch.as_tensor(example_mapped_labels)

    example_target_dict["labels"] = example_mapped_labels

    return example_target_dict


@overload
def detection_subset(
    dataset: SupervisedDetectionDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def detection_subset(
    dataset: SupportedDetectionDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int]],
    targets: Sequence[TTargetType],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def detection_subset(
    dataset: SupportedDetectionDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset: ...


def detection_subset(
    dataset: SupportedDetectionDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[DetectionDataset, SupervisedDetectionDataset]:
    """Creates an ``AvalancheSubset`` instance.

    For simple subset operations you should use the method
    `dataset.subset(indices)`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`.
    This Dataset also supports transformations, slicing, advanced indexing,
    the targets field, class mapping and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset: The whole dataset.
    :param indices: Indices in the whole set selected for subset. Can
        be None, which means that the whole dataset will be returned.
    :param class_mapping: A list that, for each possible class label value,
        contains its corresponding remapped value. Can be None.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label for each instance. Must be a sequence
        of ints, one for each instance in the dataset. This can either be a
        list of task labels for the original dataset or the list of task
        labels for the instances of the subset (an automatic detection will
        be made). In the unfortunate case in which the original dataset and
        the subset contain the same number of instances, then this parameter
        is considered to contain the task labels of the subset.
        Alternatively can be a single int value, in which case
        that value will be used as the task label for all the instances.
        Defaults to None, which means that the dataset will try to
        obtain the task labels from the original dataset. If no task labels
        could be found, a default task label 0 will be applied to all
        instances.
    :param targets: The target dictionary of each pattern. Defaults to None,
        which means that the targets will be retrieved from the dataset (if
        possible). This can either be a list of target dictionaries for the
        original dataset or the list of target dictionaries for the instances
        of the subset (an automatic detection will be made). In the
        unfortunate case in which the original dataset and the subset contain
        the same number of instances, then this parameter is considered to
        contain the target dictionaries of the subset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a
        `collate_fn` field exists in the dataset. If no such field exists,
        the default collate function for detection will be used.
    """

    is_supervised = isinstance(dataset, SupervisedDetectionDataset)

    if isinstance(dataset, DetectionDataset):
        if (
            class_mapping is None
            and transform is None
            and target_transform is None
            and transform_groups is None
            and initial_transform_group is None
            and task_labels is None
            and targets is None
            and collate_fn is None
        ):
            return dataset.subset(indices)

    targets_data: Optional[DataAttribute[TTargetType]] = _init_targets(
        dataset, targets, check_shape=False
    )
    task_labels_data: Optional[DataAttribute[int]] = _init_task_labels(
        dataset, task_labels, check_shape=False
    )

    del task_labels
    del targets

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )

    if initial_transform_group is not None and isinstance(dataset, AvalancheDataset):
        dataset = dataset.with_transforms(initial_transform_group)

    if class_mapping is not None:  # update targets
        if targets_data is None:
            # Should not happen
            # The following line usually fails
            targets_data = dataset.targets  # type: ignore

        assert (
            targets_data is not None
        ), "To execute the class mapping, a list of targets is required."

        tgs = [
            _detection_class_mapping_transform(class_mapping, example_target_dict)
            for example_target_dict in targets_data
        ]

        targets_data = DataAttribute(tgs, "targets")

    if class_mapping is not None:
        mapping_fn = partial(_detection_class_mapping_transform, class_mapping)
        frozen_transform_groups = DefaultTransformGroups((None, mapping_fn))
    else:
        frozen_transform_groups = None

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    # Check if supervision data has been added
    is_supervised = is_supervised or (
        targets_data is not None and task_labels_data is not None
    )

    if collate_fn is None:
        collate_fn = detection_collate_fn

    if is_supervised:
        return SupervisedDetectionDataset(
            [dataset],
            indices=list(indices) if indices is not None else None,
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn,
        )
    else:
        return DetectionDataset(
            [dataset],
            indices=list(indices) if indices is not None else None,
            data_attributes=das if len(das) > 0 else None,
            transform_groups=transform_gs,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn,
        )
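

# Editor's sketch (not part of the original Avalanche source): takes the
# first ten patterns of ``dataset`` and lazily remaps the box labels through
# a frozen transform group (an identity mapping over an assumed 3-class
# label space is used here purely for illustration).
def _demo_detection_subset(dataset: DetectionDataset) -> DetectionDataset:
    return detection_subset(dataset, list(range(10)), class_mapping=[0, 1, 2])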


@overload
def concat_detection_datasets(
    datasets: Sequence[SupervisedDetectionDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def concat_detection_datasets(
    datasets: Sequence[SupportedDetectionDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]],
    targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]],
    collate_fn: Optional[Callable[[List], Any]] = None
) -> SupervisedDetectionDataset: ...


@overload
def concat_detection_datasets(
    datasets: Sequence[SupportedDetectionDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset: ...


def concat_detection_datasets(
    datasets: Sequence[SupportedDetectionDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None,
    targets: Optional[
        Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> Union[DetectionDataset, SupervisedDetectionDataset]:
    """Creates an ``AvalancheConcatDataset`` instance.

    For simple concatenation operations you should use the method
    `dataset.concat(other)` or
    `concat_datasets` from `avalanche.benchmarks.utils.utils`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch
    :class:`torch.utils.data.ConcatDataset`. However, this Dataset also supports
    transformations, slicing, advanced indexing and the targets field and all
    the other goodies listed in :class:`AvalancheDataset`.

    This dataset guarantees that the operations involving the transformations
    and transformations groups are consistent across the concatenated dataset
    (if they are subclasses of :class:`AvalancheDataset`).

    :param datasets: A collection of datasets.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that if all
        AvalancheDatasets in the input datasets list agree on a common
        group (the "current group" is the same for all datasets), then that
        group will be used as the initial one. If the list of input datasets
        does not contain an AvalancheDataset or if the AvalancheDatasets
        do not agree on a common group, then 'train' will be used.
    :param targets: The target dictionary of each pattern. Can either be a
        sequence of target dictionaries or, alternatively, a sequence
        containing sequences of target dictionaries (one for each dataset to
        be concatenated). Defaults to None, which means that the targets
        will be retrieved from the datasets (if possible).
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively, task
        labels can be expressed as a sequence containing sequences of ints
        (one for each dataset to be concatenated) or even a single int,
        in which case that value will be used as the task label for all
        instances. Defaults to None, which means that the dataset will try
        to obtain the task labels from the original datasets. If no task
        labels could be found for a dataset, a default task label 0 will
        be applied to all patterns of that dataset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a `collate_fn`
        field exists in the first dataset. If no such field exists, the
        default collate function for detection will be used.
        Beware that the chosen collate function will be applied to all
        the concatenated datasets even if a different collate is defined
        in different datasets.
    """
    dds = []
    per_dataset_task_labels = _split_user_def_task_label(datasets, task_labels)

    per_dataset_targets = _split_user_def_targets(
        datasets, targets, lambda x: isinstance(x, dict)
    )

    for dd, dataset_task_labels, dataset_targets in zip(
        datasets, per_dataset_task_labels, per_dataset_targets
    ):
        dd = make_detection_dataset(
            dd,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=dataset_task_labels,
            targets=dataset_targets,
            collate_fn=collate_fn,
        )
        dds.append(dd)

    if (
        transform is None
        and target_transform is None
        and transform_groups is None
        and initial_transform_group is None
        and task_labels is None
        and targets is None
        and collate_fn is None
        and len(datasets) > 0
    ):
        d0 = datasets[0]
        if isinstance(d0, DetectionDataset):
            for d1 in datasets[1:]:
                d0 = d0.concat(d1)
            return d0

    das: List[DataAttribute] = []
    if len(dds) > 0:
        #######################################
        # TRANSFORMATION GROUPS
        #######################################
        transform_groups_obj = _init_transform_groups(
            transform_groups,
            transform,
            target_transform,
            initial_transform_group,
            dds[0],
        )

        # Find common "current_group" or use "train"
        if initial_transform_group is None:
            uniform_group = None
            for d_set in datasets:
                if isinstance(d_set, AvalancheDataset):
                    if uniform_group is None:
                        uniform_group = d_set._transform_groups.current_group
                    else:
                        if uniform_group != d_set._transform_groups.current_group:
                            uniform_group = None
                            break

            if uniform_group is None:
                initial_transform_group = "train"
            else:
                initial_transform_group = uniform_group

        #######################################
        # DATA ATTRIBUTES
        #######################################

        totlen = sum([len(d) for d in datasets])
        if task_labels is not None:  # User-defined task labels take precedence
            all_labels: IDataset[int]
            if isinstance(task_labels, int):
                all_labels = ConstantSequence(task_labels, totlen)
            else:
                all_labels_lst = []
                for dd, dataset_task_labels in zip(dds, per_dataset_task_labels):
                    assert dataset_task_labels is not None

                    # We already checked that len(t_labels) == len(dataset)
                    # (done in _split_user_def_task_label)
                    if isinstance(dataset_task_labels, int):
                        all_labels_lst.extend([dataset_task_labels] * len(dd))
                    else:
                        all_labels_lst.extend(dataset_task_labels)
                all_labels = all_labels_lst
            das.append(
                DataAttribute(all_labels, "targets_task_labels", use_in_getitem=True)
            )

        if targets is not None:  # User-defined targets always take precedence
            all_targets_lst: List[TTargetType] = []
            for dd, dataset_targets in zip(dds, per_dataset_targets):
                assert dataset_targets is not None

                # We already checked that len(targets) == len(dataset)
                # (done in _split_user_def_targets)
                all_targets_lst.extend(dataset_targets)
            das.append(DataAttribute(all_targets_lst, "targets"))
    else:
        transform_groups_obj = None
        initial_transform_group = "train"

    data = DetectionDataset(
        dds,
        transform_groups=transform_groups_obj,
        data_attributes=das if len(das) > 0 else None,
    )
    return data.with_transforms(initial_transform_group)
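

# Editor's sketch (not part of the original Avalanche source): concatenates
# two experiences while assigning task label 0 to every pattern of the first
# dataset and 1 to every pattern of the second (one sequence of ints per
# dataset to be concatenated).
def _demo_concat_as_new_task(
    d_task0: DetectionDataset, d_task1: DetectionDataset
) -> DetectionDataset:
    return concat_detection_datasets(
        [d_task0, d_task1],
        task_labels=[[0] * len(d_task0), [1] * len(d_task1)],
    )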


def _select_targets(
    dataset: SupportedDetectionDataset, indices: Optional[List[int]]
) -> Sequence[TTargetType]:
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets
    else:
        raise ValueError("Unsupported dataset: must have a valid targets field")

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets


__all__ = [
    "SupportedDetectionDataset",
    "DetectionDataset",
    "make_detection_dataset",
    "detection_subset",
    "concat_detection_datasets",
]