ContinualAI / avalanche, build 5257858549 (push via GitHub, committed by web-flow)
Status: pending completion

Merge pull request #1414 from AntonioCarta/replay-flattening ("Replay flattening")

61 of 68 new or added lines in 5 files covered (89.71%)
16241 of 22388 relevant lines covered (72.54%)
2.9 hits per line

Source file: /avalanche/benchmarks/utils/flat_data.py (84.51% covered)
################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Datasets with optimized concat/subset operations.
"""
import bisect

import numpy as np

from avalanche.benchmarks.utils.dataset_utils import (
    slice_alike_object_to_indices,
)
try:
    from collections import Hashable
except ImportError:
    from collections.abc import Hashable

from typing import (
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)

from torch.utils.data import ConcatDataset

from avalanche.benchmarks.utils.dataset_definitions import IDataset


TFlatData = TypeVar('TFlatData', bound='FlatData')
DataT = TypeVar("DataT")
T_co = TypeVar("T_co", covariant=True)

class FlatData(IDataset[T_co], Sequence[T_co]):
    """FlatData is a dataset optimized for efficient repeated concatenation
    and subset operations.

    The class combines concatenation and subsampling operations in a single
    class.

    Class for internal use only. Users should use `AvalancheDataset` for data
    or `DataAttribute` for attributes such as class and task labels.

    *Notes for subclassing*

    Cat/Sub operations are "flattened" if possible, which means that they will
    take the datasets and indices from their arguments and create a new dataset
    with them, avoiding the creation of large trees of datasets (which is what
    would happen with PyTorch datasets). Flattening is not always possible, for
    example if the data contains additional info (e.g. custom transformation
    groups), so subclasses MUST set `can_flatten` properly in order to avoid
    nasty bugs.
    """

    def __init__(
        self,
        datasets: Sequence[IDataset[T_co]],
        indices: Optional[List[int]] = None,
        can_flatten: bool = True,
    ):
        """Constructor

        :param datasets: list of datasets to concatenate.
        :param indices:  list of indices.
        :param can_flatten: if True, enables flattening.
        """
        self._datasets: List[IDataset[T_co]] = list(datasets)
        self._indices: Optional[List[int]] = indices
        self._can_flatten: bool = can_flatten

        if can_flatten:
            self._datasets = _flatten_dataset_list(self._datasets)
            self._datasets, self._indices = _flatten_datasets_and_reindex(
                self._datasets, self._indices)
        self._cumulative_sizes = ConcatDataset.cumsum(self._datasets)

        # NOTE: check disabled to avoid slowing down OCL scenarios
        # # check indices
        # if self._indices is not None and len(self) > 0:
        #     assert min(self._indices) >= 0
        #     assert max(self._indices) < self._cumulative_sizes[-1], \
        #         f"Max index {max(self._indices)} is greater than datasets " \
        #         f"list size {self._cumulative_sizes[-1]}"

    def _get_indices(self):
        """This method creates indices on-the-fly if self._indices=None.
        Only for internal use. Call may be expensive if self._indices=None.
        """
        if self._indices is not None:
            return self._indices
        else:
            return list(range(len(self)))

    def subset(self: TFlatData, indices: Optional[Iterable[int]]) -> TFlatData:
        """Subsampling operation.

        :param indices: indices of the new samples
        :return:
        """
        if indices is not None and not isinstance(indices, list):
            indices = list(indices)

        if self._can_flatten and indices is not None:
            if self._indices is None:
                new_indices = indices
            else:
                self_indices = self._get_indices()
                new_indices = [self_indices[x] for x in indices]
            return self.__class__(datasets=self._datasets, indices=new_indices)
        return self.__class__(datasets=[self], indices=indices)

    def concat(self: TFlatData, other: TFlatData) -> TFlatData:
        """Concatenation operation.

        :param other: other dataset.
        :return:
        """
        if (not self._can_flatten) and (not other._can_flatten):
            return self.__class__(datasets=[self, other])

        # Case 1: one is a subset of the other
        if len(self._datasets) == 1 and len(other._datasets) == 1:
            if self._can_flatten and self._datasets[0] is other:
                return other.subset(self._get_indices() +
                                    list(range(len(other))))
            elif other._can_flatten and other._datasets[0] is self:
                return self.subset(list(range(len(self))) +
                                   other._get_indices())
            elif (
                self._can_flatten
                and other._can_flatten
                and self._datasets[0] is other._datasets[0]
            ):
                idxs = self._get_indices() + other._get_indices()
                return self.__class__(datasets=self._datasets, indices=idxs)

        # Case 2: at least one of them can be flattened
        if self._can_flatten and other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                new_indices = self._get_indices() + [
                    base_other + idx for idx in other._get_indices()
                ]
            return self.__class__(
                datasets=self._datasets + other._datasets, indices=new_indices
            )
        elif self._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                new_indices = self._get_indices() + [
                    base_other + idx for idx in range(len(other))
                ]
            return self.__class__(
                datasets=self._datasets + [other], indices=new_indices
            )
        elif other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                base_other = len(self)
                self_idxs = list(range(len(self)))
                new_indices = self_idxs + [
                    base_other + idx for idx in other._get_indices()
                ]
            return self.__class__(
                datasets=[self] + other._datasets, indices=new_indices
            )
        else:
            assert False, "should never get here"

    def _get_idx(self, idx) -> Tuple[int, int]:
        """Return the index as a tuple <dataset-idx, sample-idx>.

        The first index indicates the dataset to use from `self._datasets`,
        while the second is the index of the sample in
        `self._datasets[dataset-idx]`.

        Private method.
        """
        if self._indices is not None:  # subset indexing
            idx = self._indices[idx]
        if len(self._datasets) == 1:
            dataset_idx = 0
        else:  # concat indexing
            dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)
            if dataset_idx > 0:
                idx = idx - self._cumulative_sizes[dataset_idx - 1]
        return dataset_idx, int(idx)

    @overload
    def __getitem__(self, item: int) -> T_co:
        ...

    @overload
    def __getitem__(self: TFlatData, item: slice) -> TFlatData:
        ...

    def __getitem__(self: TFlatData, item: Union[int, slice]) -> \
            Union[T_co, TFlatData]:
        if isinstance(item, (int, np.integer)):
            dataset_idx, idx = self._get_idx(int(item))
            return self._datasets[dataset_idx][idx]
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=item,
                max_length=len(self)
            )
            return self.subset(
                indices=slice_indices
            )

    def __len__(self) -> int:
        if len(self._cumulative_sizes) == 0:
            return 0
        elif self._indices is not None:
            return len(self._indices)
        return self._cumulative_sizes[-1]

    def __add__(self: TFlatData, other: TFlatData) -> TFlatData:
        return self.concat(other)

    def __radd__(self: TFlatData, other: TFlatData) -> TFlatData:
        return other.concat(self)

    def __hash__(self):
        return id(self)

    def __repr__(self):
        return _flatdata_repr(self)

    def _tree_depth(self):
        """Return the depth of the tree of datasets.
        Use only to debug performance issues.
        """
        return _flatdata_depth(self)
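
The flattening behavior described in the class docstring can be checked directly. The snippet below is an editorial sketch, not part of the file; it assumes `avalanche` is installed so that the module shown here is importable, and it uses plain Python lists as leaf datasets:

    from avalanche.benchmarks.utils.flat_data import FlatData

    d1 = FlatData([list(range(10))])
    d2 = FlatData([list(range(10, 20))])

    cat = d1.concat(d2)            # one FlatData over both leaf lists
    assert len(cat) == 20 and cat[15] == 15

    sub = cat.subset([0, 5, 19])   # indices are remapped instead of nesting
    assert list(sub) == [0, 5, 19]
    assert cat._tree_depth() == 2  # root FlatData + leaf lists (private debug helper)
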
class ConstantSequence(IDataset[DataT], Sequence[DataT]):
    """A memory-efficient constant sequence."""

    def __init__(self, constant_value: DataT, size: int):
        """Constructor

        :param constant_value: the fixed value
        :param size: length of the sequence
        """
        self._constant_value = constant_value
        self._size = size

    def __len__(self):
        return self._size

    @overload
    def __getitem__(self, index: int) -> DataT:
        ...

    @overload
    def __getitem__(self, index: slice) -> 'ConstantSequence[DataT]':
        ...

    def __getitem__(self, index: Union[int, slice]) -> \
            'Union[DataT, ConstantSequence[DataT]]':
        if isinstance(index, (int, np.integer)):
            index = int(index)

            if index >= len(self):
                raise IndexError()
            return self._constant_value
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=index,
                max_length=len(self)
            )
            return ConstantSequence(
                constant_value=self._constant_value,
                size=sum(1 for _ in slice_indices)
            )

    def subset(self, indices: List[int]) -> "ConstantSequence[DataT]":
        """Subset

        :param indices: indices of the new data.
        :return:
        """
        return ConstantSequence(self._constant_value, len(indices))

    def concat(self, other: FlatData[DataT]) -> IDataset[DataT]:
        """Concatenation

        :param other: other dataset
        :return:
        """
        if (
            isinstance(other, ConstantSequence)
            and self._constant_value == other._constant_value
        ):
            return ConstantSequence(
                self._constant_value, len(self) + len(other)
            )
        else:
            return FlatData([self, other])

    def __str__(self):
        return (
            f"ConstantSequence(value={self._constant_value}, len={self._size})"
        )

    def __hash__(self):
        return id(self)

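A quick illustration of the constant-size storage (again an editorial sketch, not part of the file):

    from avalanche.benchmarks.utils.flat_data import ConstantSequence

    labels = ConstantSequence(0, 1_000_000)   # one value + one int, not a list
    assert labels[123_456] == 0
    assert len(labels[10:20]) == 10           # slicing preserves the constant
    merged = labels.concat(ConstantSequence(0, 10))
    assert isinstance(merged, ConstantSequence) and len(merged) == 1_000_010
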
def _flatten_dataset_list(
        datasets: List[Union[FlatData[T_co], IDataset[T_co]]]) -> \
            List[IDataset[T_co]]:
    """Flatten the dataset tree if possible."""
    # Concat -> Concat branch
    # Flattens by borrowing the list of concatenated datasets
    # from the original datasets.
    flattened_list: List[IDataset[T_co]] = []
    for dataset in datasets:
        if len(dataset) == 0:
            continue
        elif (
            isinstance(dataset, FlatData)
            and dataset._indices is None
            and dataset._can_flatten
        ):
            flattened_list.extend(dataset._datasets)
        else:
            flattened_list.append(dataset)

    # merge consecutive Subsets if compatible
    new_data_list: List[IDataset[T_co]] = []
    for dataset in flattened_list:
        last_dataset = new_data_list[-1] if len(new_data_list) > 0 else None
        if (
            isinstance(dataset, FlatData)
            and len(new_data_list) > 0
            and isinstance(last_dataset, FlatData)
        ):
            new_data_list.pop()
            merged_ds = _maybe_merge_subsets(last_dataset, dataset)
            new_data_list.extend(merged_ds)
        elif (
            (dataset is not None)
            and len(new_data_list) > 0
            and (last_dataset is not None)
            and last_dataset is dataset
        ):
            new_data_list.pop()
            # the same dataset is repeated, using indices to avoid repeating it
            idxs = list(list(range(len(last_dataset))) * 2)
            merged_ds = [FlatData([last_dataset], indices=idxs)]
            new_data_list.extend(merged_ds)
        else:
            new_data_list.append(dataset)
    return new_data_list
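
Continuing the sketch above, the repeated-dataset branch of this helper is observable through the public constructor:

    data = list(range(5))
    rep = FlatData([data, data])    # the same list object, twice
    assert len(rep) == 10
    assert list(rep) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert len(rep._datasets) == 1  # stored once; the index list is doubled
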
def _flatten_datasets_and_reindex(
        datasets: List[IDataset],
        indices: Optional[List[int]]) -> \
            Tuple[List[IDataset], Optional[List[int]]]:
    """The same dataset may occur multiple times in the list of datasets.

    Here, we flatten the list of datasets and fix the indices to account for
    the new dataset position in the list.
    """
    # find unique datasets
    if not all(isinstance(d, Hashable) for d in datasets):
        return datasets, indices

    dset_uniques = set(datasets)
    if len(dset_uniques) == len(datasets):  # no duplicates. Nothing to do.
        return datasets, indices

    # split the indices into <dataset-id, sample-id> pairs
    cumulative_sizes = [0] + ConcatDataset.cumsum(datasets)
    data_sample_pairs: List[Tuple[int, int]] = []
    if indices is None:
        for ii, dset in enumerate(datasets):
            data_sample_pairs.extend([(ii, jj) for jj in range(len(dset))])
    else:
        for idx in indices:
            d_idx = bisect.bisect_right(cumulative_sizes, idx) - 1
            s_idx = idx - cumulative_sizes[d_idx]
            data_sample_pairs.append((d_idx, s_idx))

    # assign a new position in the list to each unique dataset
    new_datasets = list(dset_uniques)
    new_dpos = {d: i for i, d in enumerate(new_datasets)}
    new_cumsizes = [0] + ConcatDataset.cumsum(new_datasets)
    # reindex the indices to account for the new dataset position
    new_indices: List[int] = []
    for d_idx, s_idx in data_sample_pairs:
        new_d_idx = new_dpos[datasets[d_idx]]
        new_indices.append(new_cumsizes[new_d_idx] + s_idx)

    # NOTE: check disabled to avoid slowing down OCL scenarios
    # if len(new_indices) > 0 and new_cumsizes[-1] > 0:
    #     assert min(new_indices) >= 0
    #     assert max(new_indices) < new_cumsizes[-1], \
    #         f"Max index {max(new_indices)} is greater than datasets " \
    #         f"list size {new_cumsizes[-1]}"
    return new_datasets, new_indices
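
Continuing the sketch: `can_flatten=False` keeps the inner wrappers opaque, so a repeated dataset survives until this function deduplicates it and remaps the indices:

    a = FlatData([[1, 2]], can_flatten=False)
    b = FlatData([[3, 4]], can_flatten=False)
    dup = FlatData([a, b, a])       # `a` occurs twice in the argument list
    assert list(dup) == [1, 2, 3, 4, 1, 2]
    assert len(dup._datasets) == 2  # `a` is stored once; indices were remapped
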
def _maybe_merge_subsets(d1: FlatData, d2: FlatData):
    """Check the conditions for merging subsets."""
    if (not d1._can_flatten) or (not d2._can_flatten):
        return [d1, d2]

    if (
        len(d1._datasets) == 1
        and len(d2._datasets) == 1
        and d1._datasets[0] is d2._datasets[0]
    ):
        # return [d1.__class__(d1._datasets, d1._indices + d2._indices)]
        return [d1.concat(d2)]
    return [d1, d2]


def _flatdata_depth(dataset):
    """Internal debugging method.
    Returns the depth of the dataset tree."""
    if isinstance(dataset, FlatData):
        dchilds = [_flatdata_depth(dd) for dd in dataset._datasets]
        if len(dchilds) == 0:
            return 1
        return 1 + max(dchilds)
    else:
        return 1


def _flatdata_print(dataset, indent=0):
    """Internal debugging method.
    Print the dataset."""
    print(_flatdata_repr(dataset, indent))


def _flatdata_repr(dataset, indent=0):
    """Return the string representation of the dataset.
    Shows the underlying dataset tree.
    """
    if isinstance(dataset, FlatData):
        ss = dataset._indices is not None
        cc = len(dataset._datasets) != 1
        cf = dataset._can_flatten
        s = (
            "\t" * indent
            + f"{dataset.__class__.__name__} (len={len(dataset)},subset={ss},"
            f"cat={cc},cf={cf})\n"
        )
        for dd in dataset._datasets:
            s += _flatdata_repr(dd, indent + 1)
        return s
    else:
        return "\t" * indent + f"{dataset.__class__.__name__} " \
                               f"(len={len(dataset)})\n"


__all__ = ["FlatData", "ConstantSequence"]
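
Finally, the debug repr renders the flattened tree, which is the quickest way to confirm that repeated operations did not deepen the hierarchy. A last sketch (child lines in the real output are tab-indented):

    print(FlatData([list(range(3))]).subset([0, 2]))
    # FlatData (len=2,subset=True,cat=False,cf=True)
    #     list (len=3)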