ContinualAI / avalanche, build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200

417 of 538 new or added lines in 30 files covered (77.51%)
43 existing lines in 5 files now uncovered
16586 of 22630 relevant lines covered (73.29%), 2.93 hits per line

Source file: /avalanche/benchmarks/utils/utils.py (73.2% covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

""" Common benchmarks/environments utils. """

from collections import OrderedDict, defaultdict, deque
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Iterator,
    List,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    TypeVar,
    Union,
    Dict,
    SupportsInt,
)
import warnings
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import Subset, ConcatDataset, TensorDataset

from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.benchmarks.utils.dataset_definitions import (
    ISupportedClassificationDataset,
)
from avalanche.benchmarks.utils.dataset_utils import (
    SubSequence,
    find_list_from_index,
)
from avalanche.benchmarks.utils.flat_data import ConstantSequence
from avalanche.benchmarks.utils.transform_groups import (
    TransformGroupDef,
    TransformGroups,
    XTransform,
    YTransform
)

if TYPE_CHECKING:
    from avalanche.benchmarks.utils.classification_dataset import (
        ClassificationDataset
    )

T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")


def tensor_as_list(sequence) -> List:
    # Numpy: list(np.array([1, 2, 3])) returns [1, 2, 3]
    # whereas: list(torch.tensor([1, 2, 3])) returns ->
    # -> [tensor(1), tensor(2), tensor(3)]
    #
    # This is why we have to handle Tensor in a different way
    if isinstance(sequence, list):
        return sequence
    if not isinstance(sequence, Iterable):
        return [sequence]
    if isinstance(sequence, Tensor):
        return sequence.tolist()
    return list(sequence)

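# Example (illustrative sketch, not taken from the module's tests): the helper
# flattens tensors and wraps scalars so that the result is a plain Python list.
#
#   >>> tensor_as_list(torch.tensor([1, 2, 3]))
#   [1, 2, 3]
#   >>> tensor_as_list(5)
#   [5]
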
def _indexes_grouped_by_classes(
    targets: Sequence[int],
    patterns_indexes: Union[None, Sequence[int]],
    sort_indexes: bool = True,
    sort_classes: bool = True,
) -> Union[List[int], None]:
    result_per_class: Dict[int, List[int]] = OrderedDict()
    result: List[int] = []

    indexes_was_none = patterns_indexes is None

    if patterns_indexes is not None:
        patterns_indexes = tensor_as_list(patterns_indexes)
    else:
        patterns_indexes = list(range(len(targets)))

    targets = tensor_as_list(targets)

    # Consider that result_per_class is an OrderedDict
    # This means that, if sort_classes is True, the next for statement
    # will initialize "result_per_class" in sorted order which in turn means
    # that patterns will be ordered by ascending class ID.
    classes = torch.unique(
        torch.as_tensor(targets), sorted=sort_classes
    ).tolist()

    for class_id in classes:
        result_per_class[class_id] = []

    # Stores each pattern index in the appropriate class list
    for idx in patterns_indexes:
        result_per_class[targets[idx]].append(idx)

    # Concatenate all the pattern indexes
    for class_id in classes:
        if sort_indexes:
            result_per_class[class_id].sort()
        result.extend(result_per_class[class_id])

    if result == patterns_indexes and indexes_was_none:
        # Result is [0, 1, 2, ..., N] and patterns_indexes was originally None
        # This means that the user tried to obtain a full Dataset
        # (indexes_was_none) only ordered according to the sort_indexes and
        # sort_classes parameters. However, sort_indexes+sort_classes returned
        # the plain pattern sequence as it already is. So the original Dataset
        # already satisfies the sort_indexes+sort_classes constraints.
        # By returning None, we communicate that the Dataset can be taken as-is.
        return None

    return result


def grouped_and_ordered_indexes(
    targets: Sequence[int],
    patterns_indexes: Union[None, Sequence[int]],
    bucket_classes: bool = True,
    sort_classes: bool = False,
    sort_indexes: bool = False,
) -> Union[List[int], None]:
    """
    Given the targets list of a dataset and the patterns to include, returns
    the pattern indexes sorted according to the ``bucket_classes``,
    ``sort_classes`` and ``sort_indexes`` parameters.

    :param targets: The list of pattern targets.
    :param patterns_indexes: A list of pattern indexes to include in the set.
        If None, all patterns will be included.
    :param bucket_classes: If True, pattern indexes will be returned so that
        patterns will be grouped by class. Defaults to True.
    :param sort_classes: If both ``bucket_classes`` and ``sort_classes`` are
        True, class groups will be sorted by class index. Ignored if
        ``bucket_classes`` is False. Defaults to False.
    :param sort_indexes: If True, pattern indexes will be sorted. When
        bucketing by class, patterns will be sorted inside their buckets.
        Defaults to False.

    :returns: The list of pattern indexes sorted according to the
        ``bucket_classes``, ``sort_classes`` and ``sort_indexes`` parameters,
        or None if ``patterns_indexes`` is None and the whole dataset can be
        taken using the existing patterns order.
    """
    if bucket_classes:
        return _indexes_grouped_by_classes(
            targets,
            patterns_indexes,
            sort_indexes=sort_indexes,
            sort_classes=sort_classes,
        )

    if patterns_indexes is None:
        # No grouping and sub-set creation required... just return None
        return None
    if not sort_indexes:
        # No sorting required, just return patterns_indexes
        return tensor_as_list(patterns_indexes)

    # We are here only because patterns_indexes != None and sort_indexes is True
    patterns_indexes = tensor_as_list(patterns_indexes)
    result = list(patterns_indexes)  # Make sure we're working on a copy
    result.sort()
    return result

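# Example (illustrative sketch): with bucketing, class-sorting and
# index-sorting enabled, indexes are grouped per class (class 0 first).
#
#   >>> grouped_and_ordered_indexes(
#   ...     targets=[1, 0, 1, 0], patterns_indexes=None,
#   ...     bucket_classes=True, sort_classes=True, sort_indexes=True)
#   [1, 3, 0, 2]
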
def as_avalanche_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> AvalancheDataset:
    if isinstance(dataset, AvalancheDataset):
        return dataset
    return AvalancheDataset([dataset])


def as_classification_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> 'ClassificationDataset':
    from avalanche.benchmarks.utils.classification_dataset import (
        ClassificationDataset
    )

    if isinstance(dataset, ClassificationDataset):
        return dataset
    return ClassificationDataset([dataset])


def _count_unique(*sequences: Sequence[SupportsInt]):
    uniques = set()

    for seq in sequences:
        for x in seq:
            uniques.add(int(x))

    return len(uniques)


def concat_datasets(datasets):
    """Concatenates a list of datasets."""
    if len(datasets) == 0:
        return AvalancheDataset([])
    res = datasets[0]
    if not isinstance(res, AvalancheDataset):
        res = AvalancheDataset([res])

    for d in datasets[1:]:
        if not isinstance(d, AvalancheDataset):
            d = AvalancheDataset([d])
        res = res.concat(d)
    return res

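# Example (illustrative sketch): plain PyTorch datasets are wrapped into
# AvalancheDataset instances before concatenation, so the result below should
# be an AvalancheDataset holding all 5 examples.
#
#   >>> d1 = TensorDataset(torch.zeros(3, 2), torch.zeros(3))
#   >>> d2 = TensorDataset(torch.ones(2, 2), torch.ones(2))
#   >>> len(concat_datasets([d1, d2]))
#   5
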
def find_common_transforms_group(
        datasets: Iterable[Any],
        default_group: str = "train") -> str:
    """
    Utility used to find the common transformations group across multiple
    datasets.

    To compute the common group, the current group of each dataset is
    considered. Objects which are not instances of :class:`AvalancheDataset`
    are ignored. If no common group is found, then the default one is
    returned.

    :param datasets: The list of datasets.
    :param default_group: The name of the default group.
    :returns: The name of the common group.
    """
    # Find common "current_group" or use "train"
    uniform_group: Optional[str] = None
    for d_set in datasets:
        if isinstance(d_set, AvalancheDataset):
            if uniform_group is None:
                uniform_group = d_set._flat_data._transform_groups.current_group
            else:
                if (
                    uniform_group
                    != d_set._flat_data._transform_groups.current_group
                ):
                    uniform_group = None
                    break

    if uniform_group is None:
        initial_transform_group = default_group
    else:
        initial_transform_group = uniform_group

    return initial_transform_group

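# Example (illustrative sketch): objects that are not AvalancheDataset
# instances are skipped, so when no AvalancheDataset is present the default
# group name is returned.
#
#   >>> find_common_transforms_group(["not a dataset", 42])
#   'train'
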
Y = TypeVar('Y')
T = TypeVar('T')


def _traverse_supported_dataset(
    dataset: Y,
    values_selector: Callable[[Y, Optional[List[int]]], Optional[Sequence[T]]],
    indices: Optional[List[int]] = None
) -> Sequence[T]:
    """
    Traverse the given dataset by gathering required info.

    The given dataset is traversed by covering all sub-datasets
    contained in PyTorch :class:`Subset` and :class:`ConcatDataset` objects.
    Beware that instances of :class:`AvalancheDataset` will not
    be traversed as those objects already have the proper data
    attribute fields populated with data from leaf datasets.

    For each dataset, the `values_selector` will be called to gather
    the required information. The values returned by the given selector
    are then concatenated to create a final list of values.

    :param dataset: The dataset to traverse.
    :param values_selector: A function that, given the dataset
        and the indices to consider (which may be None if the entire
        dataset must be considered), returns a list of selected values.
    :param indices: The indices of the examples to consider, or None to
        consider the entire dataset.
    :returns: The list of selected values.
    """
    initial_error = None
    try:
        result = values_selector(dataset, indices)
        if result is not None:
            return result
    except BaseException as e:
        initial_error = e

    if isinstance(dataset, Subset):
        if indices is None:
            indices = [dataset.indices[x] for x in range(len(dataset))]
        else:
            indices = [dataset.indices[x] for x in indices]

        return list(
            _traverse_supported_dataset(
                dataset.dataset, values_selector, indices
            )
        )

    if isinstance(dataset, ConcatDataset):
        result = []
        if indices is None:
            for c_dataset in dataset.datasets:
                result += list(
                    _traverse_supported_dataset(
                        c_dataset, values_selector, indices
                    )
                )
            return result

        datasets_to_indexes = defaultdict(list)
        indexes_to_dataset = []
        datasets_len = []
        recursion_result = []

        all_size = 0
        for c_dataset in dataset.datasets:
            len_dataset = len(c_dataset)
            datasets_len.append(len_dataset)
            all_size += len_dataset

        for subset_idx in indices:
            dataset_idx, pattern_idx = find_list_from_index(
                subset_idx, datasets_len, all_size
            )
            datasets_to_indexes[dataset_idx].append(pattern_idx)
            indexes_to_dataset.append(dataset_idx)

        for dataset_idx, c_dataset in enumerate(dataset.datasets):
            recursion_result.append(
                deque(
                    _traverse_supported_dataset(
                        c_dataset,
                        values_selector,
                        datasets_to_indexes[dataset_idx],
                    )
                )
            )

        result = []
        for idx in range(len(indices)):
            dataset_idx = indexes_to_dataset[idx]
            result.append(recursion_result[dataset_idx].popleft())

        return result

    if initial_error is not None:
        raise initial_error

    raise ValueError("Error: can't find the needed data in the given dataset")

def _init_task_labels(dataset, task_labels, check_shape=True) -> \
        Optional[DataAttribute[int]]:
    """
    Initializes the task label list (one for each pattern in the dataset).

    Precedence is given to the values contained in `task_labels`, if passed.
    Otherwise, the elements will be retrieved from the dataset itself by
    traversing it and looking at the `targets_task_labels` field.

    :param dataset: The dataset for which the task labels list must be
        initialized. Ignored if `task_labels` is passed, but it may still be
        used if `check_shape` is true.
    :param task_labels: The task labels to use. May be None, in which case
        the labels will be retrieved from the dataset.
    :param check_shape: If True, will check if the length of the task labels
        list matches the dataset size. Ignored if the labels are retrieved
        from the dataset.
    :returns: A data attribute containing the task labels. May be None to
        signal that the dataset's `targets_task_labels` field should be used
        (because the dataset is an :class:`AvalancheDataset`).
    """
    if task_labels is not None:
        # task_labels has priority over the dataset fields
        if isinstance(task_labels, int):
            task_labels = ConstantSequence(task_labels, len(dataset))
        elif len(task_labels) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid amount of task labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(task_labels), len(dataset))
            )

        if isinstance(task_labels, ConstantSequence):
            tls = task_labels
        elif isinstance(task_labels, DataAttribute):
            tls = task_labels.data
        else:
            tls = SubSequence(task_labels, converter=int)
    else:
        task_labels = _traverse_supported_dataset(
            dataset, _select_task_labels
        )

        if task_labels is None:
            tls = None
        elif isinstance(task_labels, ConstantSequence):
            tls = task_labels
        elif isinstance(task_labels, DataAttribute):
            return DataAttribute(
                task_labels.data, "targets_task_labels",
                use_in_getitem=True)
        else:
            tls = SubSequence(task_labels, converter=int)

    if tls is None:
        return None
    return DataAttribute(tls, "targets_task_labels", use_in_getitem=True)

def _select_task_labels(dataset: Any, indices: Optional[List[int]]) -> \
        Optional[Sequence[SupportsInt]]:
    """
    Selector function to be passed to :func:`_traverse_supported_dataset`
    to obtain the `targets_task_labels` for the given dataset.

    :param dataset: the traversed dataset.
    :param indices: the indices describing the subset to consider.
    :returns: The list of task labels or None if not found.
    """
    found_task_labels: Optional[Sequence[SupportsInt]] = None
    if hasattr(dataset, "targets_task_labels"):
        found_task_labels = dataset.targets_task_labels

    if found_task_labels is None:
        if isinstance(dataset, (Subset, ConcatDataset)):
            return None  # Continue traversing

    if found_task_labels is None:
        if indices is None:
            return ConstantSequence(0, len(dataset))
        return ConstantSequence(0, len(indices))

    if indices is not None:
        found_task_labels = SubSequence(found_task_labels, indices=indices)

    return found_task_labels

def _init_transform_groups(
    transform_groups: Optional[Mapping[str, TransformGroupDef]],
    transform: Optional[XTransform],
    target_transform: Optional[YTransform],
    initial_transform_group: Optional[str],
    dataset,
) -> Optional[TransformGroups]:
    """
    Initializes the transform groups for the given dataset.

    This internal utility is commonly used to manage the transformation
    definitions coming from the user-facing API. The user may want to
    define transformations in a more classic (and simple) way by
    passing a single `transform`, or in a more elaborate way by
    passing a dictionary of groups (`transform_groups`).

    :param transform_groups: The transform groups to use as a dictionary
        (group_name -> group). Can be None. Mutually exclusive with
        `transform` and `target_transform`.
    :param transform: The transformation for the X value. Can be None.
    :param target_transform: The transformation for the Y value. Can be None.
    :param initial_transform_group: The name of the initial group.
        If None, 'train' will be used.
    :param dataset: The avalanche dataset, used only to obtain the name of
        the initial transformations group if `initial_transform_group` is
        None.
    :returns: a :class:`TransformGroups` instance if any transformation
        was passed, else None.
    """
    if transform_groups is not None and (
        transform is not None or target_transform is not None
    ):
        raise ValueError(
            "transform_groups can't be used with transform "
            "and target_transform values"
        )

    if transform_groups is not None:
        _check_groups_dict_format(transform_groups)

    if initial_transform_group is None:
        # Detect from the input dataset. If not an AvalancheDataset then
        # use 'train' as the initial transform group
        if (
            isinstance(dataset, AvalancheDataset)
            and dataset._flat_data._transform_groups is not None
        ):
            tgs = dataset._flat_data._transform_groups
            initial_transform_group = tgs.current_group
        else:
            initial_transform_group = "train"

    if transform_groups is None:
        if target_transform is None and transform is None:
            tgs = None
        else:
            tgs = TransformGroups(
                {
                    "train": (transform, target_transform),
                    "eval": (transform, target_transform),
                },
                current_group=initial_transform_group,
            )
    else:
        tgs = TransformGroups(
            transform_groups, current_group=initial_transform_group
        )
    return tgs

def _check_groups_dict_format(groups_dict):
    # The original groups_dict must be convertible to native Python dict
    groups_dict = dict(groups_dict)

    # Check if the format of the groups is correct
    for map_key in groups_dict:
        if not isinstance(map_key, str):
            raise ValueError(
                "Every group must be identified by a string. "
                'Wrong key was: "' + str(map_key) + '"'
            )

    if "test" in groups_dict:
        warnings.warn(
            'A transformation group named "test" has been found. Beware '
            "that by default AvalancheDataset supports test transformations"
            ' through the "eval" group. Consider using that one!'
        )

def _split_user_def_task_label(
    datasets,
    task_labels: Optional[Union[int,
                                Sequence[int],
                                Sequence[Sequence[int]]]]) -> \
        List[Optional[Union[int, Sequence[int]]]]:
    """
    Given a datasets list and the user-defined list of task labels,
    returns the task labels list of each dataset.

    This internal utility is mainly used to manage the different ways
    in which the user can define the task labels:
    - As a single task label for all exemplars of all datasets
    - A single list of length equal to the sum of the lengths of all datasets
    - A list containing, for each dataset, one of:
        - a list, defining the task labels of each exemplar of that dataset
        - an int, defining the task label of all exemplars of that dataset

    :param datasets: The list of datasets.
    :param task_labels: The user-defined task labels. Can be None, in which
        case a list of None will be returned.
    :returns: A list containing as many elements as the input `datasets`.
        Each element is either a list of task labels or None. If None
        (because `task_labels` is None), this means that the task labels
        should be retrieved by traversing each dataset.
    """
    t_labels = []
    idx_start = 0
    for dd_idx, dd in enumerate(datasets):
        end_idx = idx_start + len(dd)
        dataset_t_label: Optional[Union[int, Sequence[int]]]
        if task_labels is None:
            # No task label set
            dataset_t_label = None
        elif isinstance(task_labels, int):
            # Single integer (same label for all instances)
            dataset_t_label = task_labels
        elif isinstance(task_labels[0], int):
            # Single task labels sequence
            # (to be split across concatenated datasets)
            dataset_t_label = task_labels[idx_start:end_idx]  # type: ignore
        elif len(task_labels[dd_idx]) == len(dd):  # type: ignore
            # One sequence per dataset
            dataset_t_label = task_labels[dd_idx]
        else:
            raise ValueError(
                'The task_labels parameter has an invalid format.'
            )
        t_labels.append(dataset_t_label)

        idx_start = end_idx
    return t_labels

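# Example (illustrative sketch): a single flat sequence of task labels is
# split according to the lengths of the datasets, while a single int is
# simply replicated for each dataset.
#
#   >>> _split_user_def_task_label(
#   ...     [list(range(2)), list(range(3))], [1, 1, 1, 2, 2])
#   [[1, 1], [1, 2, 2]]
#   >>> _split_user_def_task_label([list(range(2)), list(range(3))], 0)
#   [0, 0]
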
def _split_user_def_targets(
        datasets,
        targets: Optional[Union[Sequence[T], Sequence[Sequence[T]]]],
        single_element_checker: Callable[[Any], bool]) -> \
            List[Optional[Sequence[T]]]:
    """
    Given a datasets list and the user-defined list of targets,
    returns the targets list of each dataset.

    This internal utility is mainly used to manage the different ways
    in which the user can define the targets:
    - A single list of length equal to the sum of the lengths of all datasets
    - A list containing, for each dataset, a list defining the targets
        of each exemplar of that dataset

    :param datasets: The list of datasets.
    :param targets: The user-defined targets. Can be None, in which
        case a list of None will be returned.
    :param single_element_checker: A function that returns True if the given
        element is a single target (and not a sequence of targets).
    :returns: A list containing as many elements as the input `datasets`.
        Each element is either a list of targets or None. If None
        (because `targets` is None), this means that the targets
        should be retrieved by traversing each dataset.
    """
    t_labels = []
    idx_start = 0
    for dd_idx, dd in enumerate(datasets):
        end_idx = idx_start + len(dd)
        dataset_t_label: Optional[Sequence[T]]
        if targets is None:
            # No targets set
            dataset_t_label = None
        elif single_element_checker(targets[0]):
            # Single targets sequence
            # (to be split across concatenated datasets)
            dataset_t_label = targets[idx_start:end_idx]  # type: ignore
        elif len(targets[dd_idx]) == len(dd):  # type: ignore
            # One sequence per dataset
            dataset_t_label = targets[dd_idx]  # type: ignore
        else:
            raise ValueError(
                'The targets parameter has an invalid format.'
            )
        t_labels.append(dataset_t_label)

        idx_start = end_idx
    return t_labels

class TaskSet(Mapping[int, TAvalancheDataset], Generic[TAvalancheDataset]):
    """A lazy mapping for <task-label -> task dataset>.

    Given an `AvalancheClassificationDataset`, this class provides an
    iterator that splits the data into task subsets, returning tuples
    `<task_id, task_dataset>`.

    Usage:

    .. code-block:: python

        tset = TaskSet(data)
        for tid, tdata in tset:
            print(f"task {tid} has {len(tdata)} examples.")

    """

    def __init__(self, data: TAvalancheDataset):
        """Constructor.

        :param data: original data
        """
        super().__init__()
        self.data: TAvalancheDataset = data

    def __iter__(self) -> Iterator[int]:
        t_labels = self._get_task_labels_field()
        return iter(t_labels.uniques)

    def __getitem__(self, task_label: int):
        t_labels = self._get_task_labels_field()
        tl_idx = t_labels.val_to_idx[task_label]
        return self.data.subset(
            tl_idx
        )

    def __len__(self) -> int:
        t_labels = self._get_task_labels_field()
        return len(t_labels.uniques)

    def _get_task_labels_field(self) -> DataAttribute[int]:
        return self.data.targets_task_labels  # type: ignore

def _numpy_is_sequence_int(numpy_tensor: np.ndarray) -> bool:
    return issubclass(numpy_tensor.dtype.type, np.integer)


def _numpy_is_single_int(numpy_tensor: np.ndarray) -> bool:
    try:
        single_value = numpy_tensor.item()
        return isinstance(single_value, int)
    except ValueError:
        return False


def _torch_is_sequence_int(torch_tensor: Tensor) -> bool:
    return not torch.is_floating_point(torch_tensor) and \
        not torch.is_complex(torch_tensor)


def _torch_is_single_int(torch_tensor: Tensor) -> bool:
    try:
        single_value = torch_tensor.item()
        return isinstance(single_value, int)
    except ValueError:
        return False


def _element_is_single_int(element: Any):
    if isinstance(element, (int, np.integer)):
        return True
    if isinstance(element, Tensor):
        return _torch_is_single_int(element)
    else:
        return False


def _is_int_iterable(iterable: Iterable[Any]):
    if isinstance(iterable, torch.Tensor):
        return _torch_is_sequence_int(iterable)
    elif isinstance(iterable, np.ndarray):
        return _numpy_is_sequence_int(iterable)
    else:
        for t in iterable:
            if not _element_is_single_int(t):
                return False
        return True


AnyT = TypeVar('AnyT', bound=Iterable)


def _to_int_list(iterable: AnyT, force: bool = True) -> Union[AnyT, List[int]]:
    if isinstance(iterable, torch.Tensor):
        if _torch_is_sequence_int(iterable):
            return iterable.tolist()
        elif force:
            raise ValueError('Cannot convert PyTorch Tensor to int list')
        else:
            return iterable
    elif isinstance(iterable, np.ndarray):
        if _numpy_is_sequence_int(iterable):
            return iterable.tolist()
        elif force:
            raise ValueError('Cannot convert NumPy array to int list')
        else:
            return iterable  # type: ignore
    else:
        int_list = []
        for t in iterable:
            if _element_is_single_int(t):
                int_list.append(t)
            elif force:
                raise ValueError('Cannot convert sequence to int list')
            else:
                return iterable
        return int_list

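# Example (illustrative sketch): integer tensors become plain int lists,
# while non-integer content is returned unchanged when force=False.
#
#   >>> _to_int_list(torch.tensor([1, 2, 3]))
#   [1, 2, 3]
#   >>> _to_int_list(["a", "b"], force=False)
#   ['a', 'b']
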
def _smart_init_targets(
    dataset,
    targets,
    check_shape=True
):
    """
    Initializes the targets for a given dataset.

    To support backwards compatibility for when
    :func:`create_multi_dataset_generic_benchmark` was
    used to manage classification benchmarks only, this function will try to
    mimic the steps taken in :func:`make_classification_dataset`, that is:

    - will try to check if the input dataset has classification
        targets (integer tensors / ndarray) and will cast them to
        a list of native ints, as expected by other parts
        of Avalanche.
    - accepts passing an int for the targets field. The given int
        will be applied to all exemplars in the dataset.
    - supports PyTorch TensorDataset, by taking the second tensor as targets.

    If targets are not of type int, then they will be returned as-is,
    so that other types of datasets (regression, detection, ...) are
    supported without issues.

    :param dataset: The input dataset. If the `targets` parameter is
        None, then targets will be retrieved from the dataset.
    :param targets: The targets to use. Can be None, in which case
        targets will be retrieved from the dataset.
    :param check_shape: If True, will check if the number of exemplars
        in the dataset matches the length of the obtained targets sequence.
    :return: The targets, as a DataAttribute of elements whose type depends
        on the input dataset.
    """
    if targets is not None:
        # User defined targets always take precedence
        if isinstance(targets, int):
            # Classification targets
            targets = ConstantSequence(targets, len(dataset))
        elif len(targets) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid number of target labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(targets), len(dataset))
            )
        return DataAttribute(targets, "targets")

    targets = _traverse_supported_dataset(
        dataset, _smart_select_targets_opt)

    if targets is not None:
        # Classification targets
        targets = _to_int_list(targets, force=False)

    if targets is None:
        return None

    return DataAttribute(targets, "targets")


def _smart_select_targets_opt(
        dataset: Any,
        indices: Optional[List[int]]) -> Optional[Sequence[Any]]:
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets
    elif hasattr(dataset, "tensors") and len(dataset.tensors) >= 2:
        # Support for PyTorch TensorDataset
        found_targets = dataset.tensors[1]
    else:
        return None

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets

def make_generic_dataset(
    dataset: Any,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Any] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> AvalancheDataset:
    """
    Helper function that creates an :class:`AvalancheDataset` with
    supervision fields `targets` and `targets_task_labels` (if given or found
    in the input dataset).

    :param dataset: The dataset to wrap in the AvalancheDataset. If it contains
        `targets` and/or `targets_task_labels` fields, then those fields will
        be inherited by the resulting dataset (if not given by the `targets`
        or `task_labels` parameters). This will also check if the input dataset
        is a :class:`TensorDataset` and, in that case, will try to use the
        second tensor as the `targets` field.
    :param transform: The transformation to apply to X values.
        Mutually exclusive with `transform_groups`.
    :param target_transform: The transformation to apply to Y values.
        Mutually exclusive with `transform_groups`.
    :param transform_groups: The transformations groups to add to the dataset.
        Mutually exclusive with `transform` and `target_transform`.
    :param initial_transform_group: The name of the transformations group to
        make current. If None, the current group of the input dataset (or
        'train') will be used.
    :param task_labels: A list containing a task label for each example. Can
        also be a plain `int`, in which case it will be applied to all
        examples. If not None, shadows the `targets_task_labels` field from
        the input dataset.
    :param targets: A list containing a target for each example. If not None,
        shadows the `targets` field from the input dataset.
    :param collate_fn: The collate function to use when loading this dataset.

    :returns: An :class:`AvalancheDataset`.
    """
    if isinstance(dataset, AvalancheDataset):
        return dataset

    transform_gs = _init_transform_groups(
        transform_groups=transform_groups,
        transform=transform,
        target_transform=target_transform,
        initial_transform_group=initial_transform_group,
        dataset=dataset,
    )

    targets_data: Optional[DataAttribute[Any]] = \
        _smart_init_targets(dataset, targets)
    task_labels_data: Optional[DataAttribute[int]] = \
        _init_task_labels(dataset, task_labels)

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    data = AvalancheDataset(
        [dataset],
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        collate_fn=collate_fn,
    )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data

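# Example (illustrative sketch): wrapping a plain TensorDataset. The second
# tensor should be picked up as the `targets` field, while `task_labels=0`
# assigns task 0 to every example.
#
#   >>> xs = torch.randn(10, 3)
#   >>> ys = torch.randint(0, 5, (10,))
#   >>> data = make_generic_dataset(TensorDataset(xs, ys), task_labels=0)
#   >>> len(data)
#   10
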
def make_generic_tensor_dataset(
    dataset_tensors: Sequence,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Any] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> AvalancheDataset:
    """
    Creates an :class:`AvalancheDataset` from raw tensors (or sequences that
    can be converted to tensors).

    The parameters follow :func:`make_generic_dataset`, with the exception of
    `dataset_tensors` (the sequences used to build the internal
    :class:`TensorDataset`) and `targets`, which can also be an int, in which
    case it is interpreted as the index of the sequence to use as the
    `targets` field.
    """
    if len(dataset_tensors) < 1:
        raise ValueError("At least one sequence must be passed")

    if isinstance(targets, int):
        targets = dataset_tensors[targets]
    tts = []
    for tt in dataset_tensors:  # TensorDataset requires PyTorch tensors
        if not hasattr(tt, 'size'):
            tt = torch.tensor(tt)
        tts.append(tt)
    dataset = TensorDataset(*tts)

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data = _smart_init_targets(dataset, targets)
    task_labels_data = _init_task_labels(dataset, task_labels)

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    data = AvalancheDataset(
        [dataset],
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        collate_fn=collate_fn,
    )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data

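# Example (illustrative sketch): raw sequences are converted to tensors and
# wrapped in a TensorDataset; with targets=1 the second sequence should be
# used as the targets field.
#
#   >>> data = make_generic_tensor_dataset(
#   ...     [torch.randn(4, 2), torch.tensor([0, 1, 0, 1])],
#   ...     targets=1, task_labels=0)
#   >>> len(data)
#   4
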
__all__ = [
    "tensor_as_list",
    "grouped_and_ordered_indexes",
    "as_avalanche_dataset",
    "as_classification_dataset",
    "concat_datasets",
    "find_common_transforms_group",
    "TaskSet",
    "make_generic_dataset",
    "make_generic_tensor_dataset"
]