
ContinualAI / avalanche · build 4894717306 (push, via GitHub) · pending completion

  • 4 of 5 new or added lines in 3 files covered (80.0%).
  • 6 existing lines in 3 files are now uncovered.
  • 15218 of 20556 relevant lines covered (74.03%).
  • 2.96 hits per line.

Source file: /avalanche/benchmarks/utils/flat_data.py (91.54% of lines covered)
################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Datasets with optimized concat/subset operations.
"""
import bisect

import numpy as np

from avalanche.benchmarks.utils.dataset_utils import (
    slice_alike_object_to_indices,
)

try:
    from collections import Hashable
except ImportError:
    from collections.abc import Hashable

from typing import (
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)

from torch.utils.data import ConcatDataset

from avalanche.benchmarks.utils.dataset_definitions import IDataset


TFlatData = TypeVar('TFlatData', bound='FlatData')
DataT = TypeVar("DataT")
T_co = TypeVar("T_co", covariant=True)

class FlatData(IDataset[T_co], Sequence[T_co]):
    """FlatData is a dataset optimized for efficient repeated concatenation
    and subset operations.

    The class combines concatenation and subsampling operations in a single
    class.

    Class for internal use only. Users should use `AvalancheDataset` for data
    or `DataAttribute` for attributes such as class and task labels.

    *Notes for subclassing*

    Cat/Sub operations are "flattened" if possible, which means that they will
    take the datasets and indices from their arguments and create a new dataset
    with them, avoiding the creation of large trees of datasets (which is what
    would happen with PyTorch datasets). Flattening is not always possible, for
    example if the data contains additional info (e.g. custom transformation
    groups), so subclasses MUST set `can_flatten` properly in order to avoid
    nasty bugs.
    """

    def __init__(
        self,
        datasets: Sequence[IDataset[T_co]],
        indices: Optional[List[int]] = None,
        can_flatten: bool = True,
    ):
        """Constructor

        :param datasets: list of datasets to concatenate.
        :param indices: list of indices.
        :param can_flatten: if True, enables flattening.
        """
        self._datasets: List[IDataset[T_co]] = list(datasets)
        self._indices: Optional[List[int]] = indices
        self._can_flatten: bool = can_flatten

        if can_flatten:
            self._datasets = _flatten_dataset_list(self._datasets)
            self._datasets, self._indices = _flatten_datasets_and_reindex(
                self._datasets, self._indices)
        self._cumulative_sizes = ConcatDataset.cumsum(self._datasets)

        # NOTE: check disabled to avoid slowing down OCL scenarios
        # # check indices
        # if self._indices is not None and len(self) > 0:
        #     assert min(self._indices) >= 0
        #     assert max(self._indices) < self._cumulative_sizes[-1], \
        #         f"Max index {max(self._indices)} is greater than datasets " \
        #         f"list size {self._cumulative_sizes[-1]}"

    def _get_indices(self):
        """This method creates indices on-the-fly if self._indices=None.
        Only for internal use. Calls may be expensive if self._indices=None.
        """
        if self._indices is not None:
            return self._indices
        else:
            return list(range(len(self)))

    def subset(self: TFlatData, indices: Optional[Iterable[int]]) -> TFlatData:
        """Subsampling operation.

        :param indices: indices of the new samples.
        :return: a new dataset restricted to the given indices.
        """
        if indices is not None and not isinstance(indices, list):
            indices = list(indices)

        if self._can_flatten and indices is not None:
            if self._indices is None:
                new_indices = indices
            else:
                self_indices = self._get_indices()
                new_indices = [self_indices[x] for x in indices]
            return self.__class__(datasets=self._datasets, indices=new_indices)
        return self.__class__(datasets=[self], indices=indices)

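Subsetting composes index lists instead of wrapping: a subset of a subset keeps pointing at the original base datasets. A minimal sketch of that behavior (illustration only, not part of flat_data.py; it assumes a plain Python list qualifies as an `IDataset`, since only `__getitem__` and `__len__` are needed):

    from avalanche.benchmarks.utils.flat_data import FlatData

    data = FlatData([["a", "b", "c", "d", "e"]])
    first = data.subset([4, 3, 2])    # view over "e", "d", "c"
    second = first.subset([0, 2])     # indices compose to [4, 2]
    assert list(second) == ["e", "c"]
    # no nesting: both views share the same base dataset
    assert second._datasets[0] is first._datasets[0]
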
    def concat(self: TFlatData, other: TFlatData) -> TFlatData:
        """Concatenation operation.

        :param other: other dataset.
        :return: the concatenated dataset.
        """
        if (not self._can_flatten) and (not other._can_flatten):
            return self.__class__(datasets=[self, other])

        # Case 1: one is a subset of the other
        if len(self._datasets) == 1 and len(other._datasets) == 1:
            if self._can_flatten and self._datasets[0] is other:
                return other.subset(self._get_indices() +
                                    list(range(len(other))))
            elif other._can_flatten and other._datasets[0] is self:
                return self.subset(list(range(len(self))) +
                                   other._get_indices())
            elif (
                self._can_flatten
                and other._can_flatten
                and self._datasets[0] is other._datasets[0]
            ):
                idxs = self._get_indices() + other._get_indices()
                return self.__class__(datasets=self._datasets, indices=idxs)

        # Case 2: at least one of them can be flattened
        if self._can_flatten and other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                new_indices = self._get_indices() + [
                    base_other + idx for idx in other._get_indices()
                ]
            return self.__class__(
                datasets=self._datasets + other._datasets, indices=new_indices
            )
        elif self._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                if len(self._cumulative_sizes) == 0:
                    base_other = 0
                else:
                    base_other = self._cumulative_sizes[-1]
                new_indices = self._get_indices() + [
                    base_other + idx for idx in range(len(other))
                ]
            return self.__class__(
                datasets=self._datasets + [other], indices=new_indices
            )
        elif other._can_flatten:
            if self._indices is None and other._indices is None:
                new_indices = None
            else:
                base_other = len(self)
                self_idxs = list(range(len(self)))
                new_indices = self_idxs + [
                    base_other + idx for idx in other._get_indices()
                ]
            return self.__class__(
                datasets=[self] + other._datasets, indices=new_indices
            )
        else:
            assert False, "should never get here"

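All branches above avoid growing the tree: when both operands share a base dataset (Case 1) the result reuses it with merged index lists, and when an operand is flattenable (Case 2) its dataset list is spliced in while the other side's indices are shifted by `base_other`. A sketch of Case 1 (illustration only, same plain-list assumption as above):

    from avalanche.benchmarks.utils.flat_data import FlatData

    base = list(range(5))
    left = FlatData([base], indices=[0, 1])
    right = FlatData([base], indices=[3, 4])
    cat = left.concat(right)          # Case 1: same base dataset
    assert list(cat) == [0, 1, 3, 4]
    assert len(cat._datasets) == 1    # indices merged, no wrapper tree
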

    def _get_idx(self, idx) -> Tuple[int, int]:
        """Return the index as a tuple <dataset-idx, sample-idx>.

        The first index indicates the dataset to use from `self._datasets`,
        while the second is the index of the sample in
        `self._datasets[dataset-idx]`.

        Private method.
        """
        if self._indices is not None:  # subset indexing
            idx = self._indices[idx]
        if len(self._datasets) == 1:
            dataset_idx = 0
        else:  # concat indexing
            dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)
            if dataset_idx > 0:
                idx = idx - self._cumulative_sizes[dataset_idx - 1]
        return dataset_idx, int(idx)

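The cumulative-size lookup mirrors `torch.utils.data.ConcatDataset`: `bisect_right` locates the first dataset whose cumulative size exceeds the flat index, and the previous cumulative size is subtracted to get the local sample index. A worked example of just that arithmetic (illustration only):

    import bisect

    cumulative_sizes = [3, 7, 12]   # three datasets of length 3, 4, 5
    flat_idx = 8
    dataset_idx = bisect.bisect_right(cumulative_sizes, flat_idx)   # -> 2
    sample_idx = flat_idx - cumulative_sizes[dataset_idx - 1]       # 8 - 7 = 1
    assert (dataset_idx, sample_idx) == (2, 1)
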
    @overload
    def __getitem__(self, item: int) -> T_co:
        ...

    @overload
    def __getitem__(self: TFlatData, item: slice) -> TFlatData:
        ...

    def __getitem__(self: TFlatData, item: Union[int, slice]) -> \
            Union[T_co, TFlatData]:
        if isinstance(item, (int, np.integer)):
            dataset_idx, idx = self._get_idx(int(item))
            return self._datasets[dataset_idx][idx]
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=item,
                max_length=len(self)
            )
            return self.subset(
                indices=slice_indices
            )

    def __len__(self) -> int:
        if len(self._cumulative_sizes) == 0:
            return 0
        elif self._indices is not None:
            return len(self._indices)
        return self._cumulative_sizes[-1]

    def __add__(self: TFlatData, other: TFlatData) -> TFlatData:
        return self.concat(other)

    def __radd__(self: TFlatData, other: TFlatData) -> TFlatData:
        return other.concat(self)

    def __hash__(self):
        return id(self)

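`__getitem__` dispatches on the index type: integers return a single sample, while slice-alike objects are expanded to indices and produce a `FlatData` view. A sketch (illustration only; it assumes `slice_alike_object_to_indices` expands a slice into its integer indices, as the name suggests):

    from avalanche.benchmarks.utils.flat_data import FlatData

    d = FlatData([list("abcdef")])
    assert d[1] == "b"    # int -> sample
    view = d[1:4]         # slice -> FlatData subset
    assert list(view) == ["b", "c", "d"]
    assert len(view) == 3
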

class ConstantSequence(IDataset[DataT], Sequence[DataT]):
    """A memory-efficient constant sequence."""

    def __init__(self, constant_value: DataT, size: int):
        """Constructor

        :param constant_value: the fixed value.
        :param size: length of the sequence.
        """
        self._constant_value = constant_value
        self._size = size

    def __len__(self):
        return self._size

    @overload
    def __getitem__(self, index: int) -> DataT:
        ...

    @overload
    def __getitem__(self, index: slice) -> 'ConstantSequence[DataT]':
        ...

    def __getitem__(self, index: Union[int, slice]) -> \
            'Union[DataT, ConstantSequence[DataT]]':
        if isinstance(index, (int, np.integer)):
            index = int(index)
            if index >= len(self):
                raise IndexError()
            return self._constant_value
        else:
            slice_indices = slice_alike_object_to_indices(
                slice_alike_object=index,
                max_length=len(self)
            )
            return ConstantSequence(
                constant_value=self._constant_value,
                size=sum(1 for _ in slice_indices)
            )

    def subset(self, indices: List[int]) -> "ConstantSequence[DataT]":
        """Subset operation.

        :param indices: indices of the new data.
        :return: a new ConstantSequence with the given length.
        """
        return ConstantSequence(self._constant_value, len(indices))

    def concat(self, other: FlatData[DataT]) -> IDataset[DataT]:
        """Concatenation operation.

        :param other: other dataset.
        :return: the concatenated dataset.
        """
        if (
            isinstance(other, ConstantSequence)
            and self._constant_value == other._constant_value
        ):
            return ConstantSequence(
                self._constant_value, len(self) + len(other)
            )
        else:
            return FlatData([self, other])

    def __str__(self):
        return (
            f"ConstantSequence(value={self._constant_value}, len={self._size})"
        )

    def __hash__(self):
        return id(self)

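`ConstantSequence` stores a single value plus a length, so constant attributes such as task labels cost O(1) memory regardless of dataset size, and concatenating two sequences with the same value just sums their lengths. A sketch (illustration only):

    from avalanche.benchmarks.utils.flat_data import ConstantSequence

    labels = ConstantSequence(0, size=1_000_000)   # one object, not a million
    assert labels[123_456] == 0
    more = labels.concat(ConstantSequence(0, size=10))
    assert isinstance(more, ConstantSequence)
    assert len(more) == 1_000_010
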
def _flatten_dataset_list(
        datasets: List[Union[FlatData[T_co], IDataset[T_co]]]) -> \
            List[IDataset[T_co]]:
    """Flatten the dataset tree if possible."""
    # Concat -> Concat branch
    # Flattens by borrowing the list of concatenated datasets
    # from the original datasets.
    flattened_list: List[IDataset[T_co]] = []
    for dataset in datasets:
        if len(dataset) == 0:
            continue
        elif (
            isinstance(dataset, FlatData)
            and dataset._indices is None
            and dataset._can_flatten
        ):
            flattened_list.extend(dataset._datasets)
        else:
            flattened_list.append(dataset)

    # merge consecutive Subsets if compatible
    new_data_list: List[IDataset[T_co]] = []
    for dataset in flattened_list:
        last_dataset = new_data_list[-1] if len(new_data_list) > 0 else None
        if (
            isinstance(dataset, FlatData)
            and len(new_data_list) > 0
            and isinstance(last_dataset, FlatData)
        ):
            new_data_list.pop()
            merged_ds = _maybe_merge_subsets(last_dataset, dataset)
            new_data_list.extend(merged_ds)
        else:
            new_data_list.append(dataset)
    return new_data_list

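The first loop borrows the children of flattenable `FlatData` nodes, so wrapping existing `FlatData` objects in a new one yields a two-level tree rather than a deeper one. A sketch (illustration only, plain lists as leaf datasets):

    from avalanche.benchmarks.utils.flat_data import FlatData

    a = FlatData([list(range(3))])
    b = FlatData([list(range(4))])
    outer = FlatData([a, b])    # Concat -> Concat branch borrows the leaves
    assert len(outer._datasets) == 2
    assert not any(isinstance(d, FlatData) for d in outer._datasets)
    assert list(outer) == [0, 1, 2, 0, 1, 2, 3]
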
def _flatten_datasets_and_reindex(
        datasets: List[IDataset],
        indices: Optional[List[int]]) -> \
            Tuple[List[IDataset], Optional[List[int]]]:
    """The same dataset may occur multiple times in the list of datasets.

    Here, we flatten the list of datasets and fix the indices to account for
    the new dataset position in the list.
    """
    # find unique datasets
    if not all(isinstance(d, Hashable) for d in datasets):
        return datasets, indices

    dset_uniques = set(datasets)
    if len(dset_uniques) == len(datasets):  # no duplicates. Nothing to do.
        return datasets, indices

    # split the indices into <dataset-id, sample-id> pairs
    cumulative_sizes = [0] + ConcatDataset.cumsum(datasets)
    data_sample_pairs: List[Tuple[int, int]] = []
    if indices is None:
        for ii, dset in enumerate(datasets):
            data_sample_pairs.extend([(ii, jj) for jj in range(len(dset))])
    else:
        for idx in indices:
            d_idx = bisect.bisect_right(cumulative_sizes, idx) - 1
            s_idx = idx - cumulative_sizes[d_idx]
            data_sample_pairs.append((d_idx, s_idx))

    # assign a new position in the list to each unique dataset
    new_datasets = list(dset_uniques)
    new_dpos = {d: i for i, d in enumerate(new_datasets)}
    new_cumsizes = [0] + ConcatDataset.cumsum(new_datasets)
    # reindex the indices to account for the new dataset position
    new_indices: List[int] = []
    for d_idx, s_idx in data_sample_pairs:
        new_d_idx = new_dpos[datasets[d_idx]]
        new_indices.append(new_cumsizes[new_d_idx] + s_idx)

    # NOTE: check disabled to avoid slowing down OCL scenarios
    # if len(new_indices) > 0 and new_cumsizes[-1] > 0:
    #     assert min(new_indices) >= 0
    #     assert max(new_indices) < new_cumsizes[-1], \
    #         f"Max index {max(new_indices)} is greater than datasets " \
    #         f"list size {new_cumsizes[-1]}"
    return new_datasets, new_indices

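When a hashable dataset occurs several times, it is stored once and every flat index is remapped onto the deduplicated layout via the <dataset-id, sample-id> decomposition above. A worked sketch (illustration only; tuples are used because they are hashable sequences):

    from avalanche.benchmarks.utils.flat_data import (
        _flatten_datasets_and_reindex,
    )

    t = ("x", "y", "z")
    datasets, indices = _flatten_datasets_and_reindex([t, t], indices=None)
    assert datasets == [t]    # the duplicate is stored only once
    assert [datasets[0][i] for i in indices] == list("xyzxyz")
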
def _maybe_merge_subsets(d1: FlatData, d2: FlatData):
    """Check the conditions for merging subsets."""
    if (not d1._can_flatten) or (not d2._can_flatten):
        return [d1, d2]

    if (
        len(d1._datasets) == 1
        and len(d2._datasets) == 1
        and d1._datasets[0] is d2._datasets[0]
    ):
        # return [d1.__class__(d1._datasets, d1._indices + d2._indices)]
        return [d1.concat(d2)]
    return [d1, d2]

def _flatdata_depth(dataset):
    """Internal debugging method.
    Returns the depth of the dataset tree."""
    if isinstance(dataset, FlatData):
        dchilds = [_flatdata_depth(dd) for dd in dataset._datasets]
        if len(dchilds) == 0:
            return 1
        return 1 + max(dchilds)
    else:
        return 1


def _flatdata_print(dataset, indent=0):
    """Internal debugging method.
    Print the dataset."""
    if isinstance(dataset, FlatData):
        ss = dataset._indices is not None
        cc = len(dataset._datasets) != 1
        cf = dataset._can_flatten
        print(
            "\t" * indent
            + f"{dataset.__class__.__name__} (len={len(dataset)},subset={ss},"
            f"cat={cc},cf={cf})"
        )
        for dd in dataset._datasets:
            _flatdata_print(dd, indent + 1)
    else:
        print(
            "\t" * indent + f"{dataset.__class__.__name__} (len={len(dataset)})"
        )

__all__ = ["FlatData", "ConstantSequence"]
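Putting it together: repeated concatenation that would build a deep `ConcatDataset` tree stays a two-level structure here, which `_flatdata_depth` can verify. A closing sketch (illustration only):

    from avalanche.benchmarks.utils.flat_data import FlatData, _flatdata_depth

    d = FlatData([list(range(10))])
    for _ in range(10):
        d = d.concat(d)    # each step only doubles the index list
    assert len(d) == 10 * 2 ** 10
    assert _flatdata_depth(d) == 2    # base list + one FlatData on top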