ContinualAI / avalanche, build 5268393053 (pending completion)

Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200 (github, web-flow)

417 of 538 new or added lines in 30 files covered (77.51%)
43 existing lines in 5 files now uncovered
16586 of 22630 relevant lines covered (73.29%)
2.93 hits per line

Source file: /avalanche/benchmarks/utils/flat_data.py (91.61% covered)
################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Datasets with optimized concat/subset operations.
"""
import bisect

import numpy as np

from avalanche.benchmarks.utils.dataset_utils import (
    slice_alike_object_to_indices,
)
try:
    from collections import Hashable
except ImportError:
    from collections.abc import Hashable

from typing import (
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)

from torch.utils.data import ConcatDataset
import itertools
from avalanche.benchmarks.utils.dataset_definitions import IDataset


TFlatData = TypeVar('TFlatData', bound='FlatData')
DataT = TypeVar("DataT")
T_co = TypeVar("T_co", covariant=True)


class LazyIndices:
    """More efficient ops for indices.

    Avoid wasteful allocations, accept generators. Convert to list only
    when needed.

    Do not use for anything outside this file.
    """

    def __init__(self, *lists, known_length=None, offset=0):
        new_lists = []
        for ll in lists:
            if isinstance(ll, LazyIndices) and ll._eager_list is not None:
                # already eagerized, don't waste work
                new_lists.append(ll._eager_list)
            else:
                new_lists.append(ll)
        self._lists = new_lists

        if len(self._lists) == 1 and offset == 0:
            self._eager_list = self._lists[0]
        else:
            self._lazy_sequence = itertools.chain(*self._lists)
            """chain of generators
            this will be consumed over time whenever we need elems.
            """
            self._eager_list = None
            """This is the list where we save indices
            whenever we generate them from the lazy sequence.
            """
            self._offset = offset

        self._known_length = known_length

    def _to_eager(self):
        if self._eager_list is not None:
            return
        self._eager_list = [el + self._offset for el in self._lazy_sequence]

    def __getitem__(self, item):
        if self._eager_list is None:
            self._to_eager()
        return self._eager_list[item]

    def __add__(self, other):
        return LazyIndices(self, other)

    def __radd__(self, other):
        return LazyIndices(other, self)

    def __len__(self):
        if self._eager_list is not None:
            return len(self._eager_list)
        elif self._known_length is not None:
            return self._known_length
        else:
            # raise ValueError("Unknown lazy list length")
            return sum(len(ll) for ll in self._lists)


class LazyRange(LazyIndices):
    """Avoid 'eagerification' step for ranges."""

    def __init__(self, start, end, offset=0):
        self._start = start
        self._end = end
        self._offset = offset
        self._known_length = end - start
        self._eager_list = self

    def _to_eager(self):
        # LazyRange is eager already
        pass

    def __iter__(self):
        for i in range(self._start, self._end):
            yield self._offset + i

    def __getitem__(self, item):
        assert item >= self._start and item < self._end, \
            "LazyRange: index out of range"
        return self._start + self._offset + item

    def __add__(self, other):
        # TODO: could be optimized to merge contiguous ranges
        return LazyIndices(self, other)

    def __len__(self):
        return self._end - self._start
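
# Illustrative sketch (added for exposition, not part of the avalanche module):
# how LazyRange/LazyIndices compose. Adding two ranges yields a LazyIndices
# whose length is derived from its children without allocating a list; the
# first __getitem__ eagerizes the chained indices exactly once.
def _example_lazy_indices():
    idxs = LazyRange(0, 3) + LazyRange(0, 3, offset=3)
    assert len(idxs) == 6  # 3 + 3, computed from the child lengths
    assert idxs[4] == 4    # chain [0, 1, 2, 3, 4, 5] materialized on first access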


class FlatData(IDataset[T_co], Sequence[T_co]):
    """FlatData is a dataset optimized for efficient repeated concatenation
    and subset operations.

    The class combines concatenation and subsampling operations in a single
    class.

    Class for internal use only. Users should use `AvalancheDataset` for data
    or `DataAttribute` for attributes such as class and task labels.

    *Notes for subclassing*

    Cat/Sub operations are "flattened" if possible, which means that they will
    take the datasets and indices from their arguments and create a new dataset
    with them, avoiding the creation of large trees of datasets (as would
    happen with PyTorch datasets). Flattening is not always possible, for
    example if the data contains additional info (e.g. custom transformation
    groups), so subclasses MUST set `can_flatten` properly in order to avoid
    nasty bugs.
    """

    def __init__(
        self,
        datasets: Sequence[IDataset[T_co]],
        indices: Optional[List[int]] = None,
        can_flatten: bool = True,
    ):
        """Constructor

        :param datasets: list of datasets to concatenate.
        :param indices:  list of indices.
        :param can_flatten: if True, enables flattening.
        """
        self._datasets: List[IDataset[T_co]] = list(datasets)
        self._indices: Optional[List[int]] = indices
        self._can_flatten: bool = can_flatten

        if can_flatten:
            self._datasets = _flatten_dataset_list(self._datasets)
            self._datasets, self._indices = _flatten_datasets_and_reindex(
                self._datasets, self._indices)
        self._cumulative_sizes = ConcatDataset.cumsum(self._datasets)

        # NOTE: check disabled to avoid slowing down OCL scenarios
        # # check indices
        # if self._indices is not None and len(self) > 0:
        #     assert min(self._indices) >= 0
        #     assert max(self._indices) < self._cumulative_sizes[-1], \
        #         f"Max index {max(self._indices)} is greater than datasets " \
        #         f"list size {self._cumulative_sizes[-1]}"

    def _get_lazy_indices(self):
        """This method creates indices on-the-fly if self._indices=None.
        Only for internal use. Call may be expensive if self._indices=None.
        """
        if self._indices is not None:
            return self._indices
        else:
            return LazyRange(0, len(self))

    def subset(self: TFlatData, indices: Optional[Iterable[int]]) -> TFlatData:
        """Subsampling operation.

        :param indices: indices of the new samples
        :return:
        """
        if indices is not None and not isinstance(indices, list):
            indices = list(indices)

        if self._can_flatten and indices is not None:
            if self._indices is None:
                new_indices = indices
            else:
                self_indices = self._get_lazy_indices()
                new_indices = [self_indices[x] for x in indices]
            return self.__class__(datasets=self._datasets, indices=new_indices)
        return self.__class__(datasets=[self], indices=indices)

    def concat(self: TFlatData, other: TFlatData) -> TFlatData:
        """Concatenation operation.

        :param other: other dataset.
        :return:
        """
        if (not self._can_flatten) and (not other._can_flatten):
            return self.__class__(datasets=[self, other])

        # Case 1: one is a subset of the other
        if len(self._datasets) == 1 and len(other._datasets) == 1:
            if self._can_flatten and self._datasets[0] is other \
                    and other._indices is None:
                idxs = self._get_lazy_indices() + other._get_lazy_indices()
                return other.subset(idxs)
            elif other._can_flatten and other._datasets[0] is self \
                    and self._indices is None:
                idxs = self._get_lazy_indices() + other._get_lazy_indices()
                return self.subset(idxs)
            elif (
                self._can_flatten
                and other._can_flatten
                and self._datasets[0] is other._datasets[0]
            ):
                idxs = LazyIndices(self._get_lazy_indices(),
                                   other._get_lazy_indices())
                return self.__class__(datasets=self._datasets, indices=idxs)

        # Case 2: at least one of them can be flattened
        if self._can_flatten and other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                other_idxs = LazyIndices(other._get_lazy_indices(),
                                         offset=base_other)
                new_indices = self._get_lazy_indices() + other_idxs
            return self.__class__(
                datasets=self._datasets + other._datasets, indices=new_indices
            )
        elif self._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                other_idxs = LazyRange(0, len(other), offset=base_other)
                new_indices = self._get_lazy_indices() + other_idxs
            return self.__class__(
                datasets=self._datasets + [other], indices=new_indices
            )
        elif other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                base_other = len(self)
                self_idxs = LazyRange(0, len(self))
                other_idxs = LazyIndices(other._get_lazy_indices(),
                                         offset=base_other)
                new_indices = self_idxs + other_idxs
            return self.__class__(
                datasets=[self] + other._datasets, indices=new_indices
            )
        else:
            assert False, "should never get here"

    def _get_idx(self, idx) -> Tuple[int, int]:
        """Return the index as a tuple <dataset-idx, sample-idx>.

        The first index indicates the dataset to use from `self._datasets`,
        while the second is the index of the sample in
        `self._datasets[dataset-idx]`.

        Private method.
        """
        if self._indices is not None:  # subset indexing
            idx = self._indices[idx]
        if len(self._datasets) == 1:
            dataset_idx = 0
        else:  # concat indexing
            dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)
            if dataset_idx == 0:
                idx = idx
            else:
                idx = idx - self._cumulative_sizes[dataset_idx - 1]
        return dataset_idx, int(idx)

    @overload
    def __getitem__(self, item: int) -> T_co:
        ...

    @overload
    def __getitem__(self: TFlatData, item: slice) -> TFlatData:
        ...

    def __getitem__(self: TFlatData, item: Union[int, slice]) -> \
            Union[T_co, TFlatData]:
        if isinstance(item, (int, np.integer)):
            dataset_idx, idx = self._get_idx(int(item))
            return self._datasets[dataset_idx][idx]
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=item,
                max_length=len(self)
            )

            return self.subset(
                indices=slice_indices
            )

    def __len__(self) -> int:
        if len(self._cumulative_sizes) == 0:
            return 0
        elif self._indices is not None:
            return len(self._indices)
        return self._cumulative_sizes[-1]

    def __add__(self: TFlatData, other: TFlatData) -> TFlatData:
        return self.concat(other)

    def __radd__(self: TFlatData, other: TFlatData) -> TFlatData:
        return other.concat(self)

    def __hash__(self):
        return id(self)

    def __repr__(self):
        return _flatdata_repr(self)

    def _tree_depth(self):
        """Return the depth of the tree of datasets.
        Use only to debug performance issues.
        """
        return _flatdata_depth(self)
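
# Illustrative sketch (added for exposition, not part of the avalanche module):
# the flattening behaviour with a plain Python list as the leaf dataset.
# Subsets and concatenations over the same leaf are merged into a single
# FlatData node instead of growing a tree of nested wrappers.
def _example_flat_data():
    leaf = list(range(10))
    flat = FlatData([leaf])
    merged = flat.subset([0, 2, 4]).concat(flat.subset([1, 3]))
    assert len(merged) == 5
    assert len(merged._datasets) == 1  # still a single node over the original list
    assert merged[3] == leaf[1]        # indices [0, 2, 4, 1, 3] resolve into the leaf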


class ConstantSequence(IDataset[DataT], Sequence[DataT]):
    """A memory-efficient constant sequence."""

    def __init__(self, constant_value: DataT, size: int):
        """Constructor

        :param constant_value: the fixed value
        :param size: length of the sequence
        """
        self._constant_value = constant_value
        self._size = size
        self._can_flatten = False
        self._indices = None

    def __len__(self):
        return self._size

    @overload
    def __getitem__(self, index: int) -> DataT:
        ...

    @overload
    def __getitem__(self, index: slice) -> 'ConstantSequence[DataT]':
        ...

    def __getitem__(self, index: Union[int, slice]) -> \
            'Union[DataT, ConstantSequence[DataT]]':
        if isinstance(index, (int, np.integer)):
            index = int(index)

            if index >= len(self):
                raise IndexError()
            return self._constant_value
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=index,
                max_length=len(self)
            )
            return ConstantSequence(
                constant_value=self._constant_value,
                size=sum(1 for _ in slice_indices)
            )

    def subset(self, indices: List[int]) -> "ConstantSequence[DataT]":
        """Subset

        :param indices: indices of the new data.
        :return:
        """
        return ConstantSequence(self._constant_value, len(indices))

    def concat(self, other: FlatData[DataT]) -> IDataset[DataT]:
        """Concatenation

        :param other: other dataset
        :return:
        """
        if (
            isinstance(other, ConstantSequence)
            and self._constant_value == other._constant_value
        ):
            return ConstantSequence(
                self._constant_value, len(self) + len(other)
            )
        else:
            return FlatData([self, other])

    def __str__(self):
        return (
            f"ConstantSequence(value={self._constant_value}, len={self._size})"
        )

    def __hash__(self):
        return id(self)
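
# Illustrative sketch (added for exposition, not part of the avalanche module):
# ConstantSequence concatenation. Sequences holding the same value merge into
# one constant sequence; different values fall back to a FlatData concat.
def _example_constant_sequence():
    zeros = ConstantSequence(0, 5).concat(ConstantSequence(0, 3))
    assert isinstance(zeros, ConstantSequence) and len(zeros) == 8
    mixed = ConstantSequence(0, 5).concat(ConstantSequence(1, 3))
    assert isinstance(mixed, FlatData) and len(mixed) == 8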


def _flatten_dataset_list(
        datasets: List[Union[FlatData[T_co], IDataset[T_co]]]) -> \
            List[IDataset[T_co]]:
    """Flatten the dataset tree if possible."""
    # Concat -> Concat branch
    # Flattens by borrowing the list of concatenated datasets
    # from the original datasets.
    flattened_list: List[IDataset[T_co]] = []
    for dataset in datasets:
        if len(dataset) == 0:
            continue
        elif (
            isinstance(dataset, FlatData)
            and dataset._indices is None
            and dataset._can_flatten
        ):
            flattened_list.extend(dataset._datasets)
        else:
            flattened_list.append(dataset)

    # merge consecutive Subsets if compatible
    new_data_list: List[IDataset[T_co]] = []
    for dataset in flattened_list:
        last_dataset = new_data_list[-1] if len(new_data_list) > 0 else None
        if (
            isinstance(dataset, FlatData)
            and len(new_data_list) > 0
            and isinstance(last_dataset, FlatData)
        ):
            new_data_list.pop()
            merged_ds = _maybe_merge_subsets(last_dataset, dataset)
            new_data_list.extend(merged_ds)
        elif (
            (dataset is not None)
            and len(new_data_list) > 0
            and (last_dataset is not None)
            and last_dataset is dataset
        ):
            new_data_list.pop()
            # the same dataset is repeated, using indices to avoid repeating it
            idxs = LazyIndices(LazyRange(0, len(last_dataset)),
                               LazyRange(0, len(last_dataset)))
            merged_ds = [FlatData([last_dataset], indices=idxs)]
            new_data_list.extend(merged_ds)
        else:
            new_data_list.append(dataset)
    return new_data_list


def _flatten_datasets_and_reindex(
        datasets: List[IDataset],
        indices: Optional[List[int]]) -> \
            Tuple[List[IDataset], Optional[List[int]]]:
    """The same dataset may occur multiple times in the list of datasets.

    Here, we flatten the list of datasets and fix the indices to account for
    the new dataset position in the list.
    """
    # find unique datasets
    if not all(isinstance(d, Hashable) for d in datasets):
        return datasets, indices

    dset_uniques = set(datasets)
    if len(dset_uniques) == len(datasets):  # no duplicates. Nothing to do.
        return datasets, indices

    # split the indices into <dataset-id, sample-id> pairs
    cumulative_sizes = [0] + ConcatDataset.cumsum(datasets)
    data_sample_pairs: List[Tuple[int, int]] = []
    if indices is None:
        for ii, dset in enumerate(datasets):
            data_sample_pairs.extend([(ii, jj) for jj in range(len(dset))])
    else:
        for idx in indices:
            d_idx = bisect.bisect_right(cumulative_sizes, idx) - 1
            s_idx = idx - cumulative_sizes[d_idx]
            data_sample_pairs.append((d_idx, s_idx))

    # assign a new position in the list to each unique dataset
    new_datasets = list(dset_uniques)
    new_dpos = {d: i for i, d in enumerate(new_datasets)}
    new_cumsizes = [0] + ConcatDataset.cumsum(new_datasets)
    # reindex the indices to account for the new dataset position
    new_indices: List[int] = []
    for d_idx, s_idx in data_sample_pairs:
        new_d_idx = new_dpos[datasets[d_idx]]
        new_indices.append(new_cumsizes[new_d_idx] + s_idx)

    # NOTE: check disabled to avoid slowing down OCL scenarios
    # if len(new_indices) > 0 and new_cumsizes[-1] > 0:
    #     assert min(new_indices) >= 0
    #     assert max(new_indices) < new_cumsizes[-1], \
    #         f"Max index {max(new_indices)} is greater than datasets " \
    #         f"list size {new_cumsizes[-1]}"
    return new_datasets, new_indices
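
# Illustrative sketch (added for exposition, not part of the avalanche module):
# the reindexing above. When the same hashable dataset appears twice in the
# list, it is stored once and the indices are rewritten against the unique
# copy.
def _example_reindex_duplicates():
    inner = FlatData([list("abc")], can_flatten=False)
    dup = FlatData([inner, inner])  # the same object listed twice
    assert len(dup._datasets) == 1  # duplicate entry collapsed
    assert dup._indices == [0, 1, 2, 0, 1, 2]
    assert list(dup) == ["a", "b", "c", "a", "b", "c"]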


def _maybe_merge_subsets(d1: FlatData, d2: FlatData):
    """Check the conditions for merging subsets."""
    if (not d1._can_flatten) or (not d2._can_flatten):
        return [d1, d2]

    if (
        len(d1._datasets) == 1
        and len(d2._datasets) == 1
        and d1._datasets[0] is d2._datasets[0]
    ):
        # return [d1.__class__(d1._datasets, d1._indices + d2._indices)]
        return [d1.concat(d2)]
    return [d1, d2]


def _flatdata_depth(dataset):
    """Internal debugging method.
    Returns the depth of the dataset tree."""
    if isinstance(dataset, FlatData):
        dchilds = [_flatdata_depth(dd) for dd in dataset._datasets]
        if len(dchilds) == 0:
            return 1
        return 1 + max(dchilds)
    else:
        return 1


def _flatdata_print(dataset, indent=0):
    """Internal debugging method.
    Print the dataset."""
    print(_flatdata_repr(dataset, indent))


def _flatdata_repr(dataset, indent=0):
    """Return the string representation of the dataset.
    Shows the underlying dataset tree.
    """
    if isinstance(dataset, FlatData):
        ss = dataset._indices is not None
        cc = len(dataset._datasets) != 1
        cf = dataset._can_flatten
        s = (
            "\t" * indent
            + f"{dataset.__class__.__name__} (len={len(dataset)},subset={ss},"
            f"cat={cc},cf={cf})\n"
        )
        for dd in dataset._datasets:
            s += _flatdata_repr(dd, indent + 1)
        return s
    else:
        return "\t" * indent + f"{dataset.__class__.__name__} " \
                               f"(len={len(dataset)})\n"
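
# Illustrative sketch (added for exposition, not part of the avalanche module):
# the debug repr on a flattened tree. Concatenating two subsets of the same
# leaf keeps the tree at depth 2 (one FlatData node over the leaf); the repr
# looks roughly like:
#   FlatData (len=4,subset=True,cat=False,cf=True)
#       list (len=4)
def _example_flatdata_repr():
    leaf = list(range(4))
    tree = FlatData([leaf]).subset([0, 1]).concat(FlatData([leaf]).subset([2, 3]))
    assert _flatdata_depth(tree) == 2
    print(_flatdata_repr(tree))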


__all__ = ["FlatData", "ConstantSequence"]