ContinualAI / avalanche · build 8098020118

29 Feb 2024 02:57PM UTC · coverage: 51.806% (-12.4%) from 64.17%
push · github · web-flow · "Update test-coverage-coveralls.yml"
14756 of 28483 relevant lines covered (51.81%) · 0.52 hits per line

Source File: /tests/benchmarks/utils/test_flat_data.py (13.37% covered)

import sys
import unittest
import random

import torch

from avalanche.benchmarks import FixedSizeExperienceSplitter
from avalanche.benchmarks.utils import AvalancheDataset, concat_datasets
from avalanche.benchmarks.utils.classification_dataset import (
    TaskAwareClassificationDataset,
)
from avalanche.benchmarks.utils.flat_data import (
    FlatData,
    LazyRange,
    _flatten_datasets_and_reindex,
    LazyIndices,
)
from avalanche.benchmarks.utils.flat_data import (
    _flatdata_depth,
    _flatdata_print,
)
from avalanche.training import ReservoirSamplingBuffer
from tests.unit_tests_utils import get_fast_benchmark


class AvalancheDatasetTests(unittest.TestCase):
    def test_flatdata_subset_concat_stack_overflow(self):
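        # Build a deep chain of subset/concat operations. FlatData is
        # expected to merge these operations eagerly, so the resulting
        # tree stays flat and indexing does not hit a RecursionError
        # (see the depth assertion at the end of this test).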
        d_sz = 5
        x_raw = torch.randint(0, 7, (d_sz,))
        data = FlatData([x_raw])
        dataset_hierarchy_depth = 500

        # prepare random permutations for each step
        perms = []
        for _ in range(dataset_hierarchy_depth):
            idx_permuted = list(range(d_sz))
            random.shuffle(idx_permuted)
            perms.append(idx_permuted)

        # compute expected indices after all permutations
        current_indices = range(d_sz)
        true_indices = []
        true_indices.append(list(current_indices))
        for idx in range(dataset_hierarchy_depth):
            current_indices = [current_indices[x] for x in perms[idx]]
            true_indices.append(current_indices)
        true_indices = list(reversed(true_indices))

        # apply permutations and concatenations iteratively
        curr_dataset = data
        for idx in range(dataset_hierarchy_depth):
            # print(idx)
            # print(idx, "depth: ", _flatdata_depth(curr_dataset))

            subset = curr_dataset.subset(indices=perms[idx])
            # print("SUBSET:")
            # _flatdata_print(subset)

            curr_dataset = subset.concat(curr_dataset)
            # print("CONCAT:")
            # _flatdata_print(curr_dataset)

        self.assertEqual(d_sz * dataset_hierarchy_depth + d_sz, len(curr_dataset))
        for idx in range(dataset_hierarchy_depth):
            leaf_range = range(idx * d_sz, (idx + 1) * d_sz)
            permuted = true_indices[idx]

            x_leaf = torch.stack([curr_dataset[i] for i in leaf_range], dim=0)
            self.assertTrue(torch.equal(x_raw[permuted], x_leaf))

        slice_idxs = list(range(d_sz * dataset_hierarchy_depth, len(curr_dataset)))
        x_slice = torch.stack([curr_dataset[i] for i in slice_idxs], dim=0)
        self.assertTrue(torch.equal(x_raw, x_slice))
        # If you broke this test, it means that dataset merging is not
        # working anymore. You are probably doing something that disables
        # merging (passing custom transforms?).
        # Good luck...
        assert _flatdata_depth(curr_dataset) == 2

    def test_merging(self):
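        # Repeated concatenation with the same leaf dataset should be
        # merged rather than nested: depth and the number of leaf
        # datasets must stay constant, even with subsets in between.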
        x = torch.randn(10)
        fdata = FlatData([x])

        dd = fdata
        for _ in range(5):
            dd = dd.concat(fdata)
            assert _flatdata_depth(dd) == 2
            assert len(dd._datasets) == 1

            idxs = list(range(len(dd)))
            random.shuffle(idxs)
            dd_old = dd
            dd = dd.subset(idxs[:12])

            for i in range(12):
                assert dd[i] == dd_old[idxs[i]]

            assert _flatdata_depth(dd) == 2
            assert len(dd._indices) == 12
            assert len(dd._datasets) == 1


class FlatteningTests(unittest.TestCase):
    def test_flatten_and_reindex(self):
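        # _flatten_datasets_and_reindex should deduplicate repeated
        # datasets and remap all indices into the single shared copy.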
        bm = get_fast_benchmark()
        D1 = bm.train_stream[0].dataset
        ds, idxs = _flatten_datasets_and_reindex([D1, D1, D1], None)

        print(
            f"len-ds: {len(ds)}, max={max(idxs)}, min={min(idxs)}, "
            f"lens={[len(d) for d in ds]}"
        )
        assert len(ds) == 1
        assert len(idxs) == 3 * len(D1)
        assert max(idxs) == len(D1) - 1
        assert min(idxs) == 0

    def test_concat_flattens_same_dataset(self):
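        # Concatenate the same AvalancheDataset over and over; the
        # flattening assertions are currently disabled below, so this
        # test only prints the depth/leaf-count diagnostics.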
        D = AvalancheDataset(
            [[1, 2, 3]],
        )
        B = concat_datasets([])
        B = B.concat(D)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")

        for _ in range(10):
            B = D.concat(B)
            print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
            # assert _flatdata_depth(B) <= 2
            # assert len(B._datasets) <= 2

    def test_concat_flattens_same_dataset_corner_case(self):
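        # Leaves created with can_flatten=False must not be merged away.
        # These checks pin the expected element order for both concat
        # directions, with indices placed on the leaf or on the wrapper.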
        base_dataset = [1, 2, 3]
        A = FlatData([base_dataset], can_flatten=False, indices=[1, 2])
        B = FlatData([A])
        C = A.concat(B)
        C[3]  # bare access, presumably to exercise indexing across the concatenation
        self.assertListEqual([2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False)
        B = FlatData([A], indices=[1, 2])
        C = A.concat(B)
        self.assertListEqual([1, 2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False, indices=[1, 2])
        B = FlatData([A])
        C = B.concat(A)
        self.assertListEqual([2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False)
        B = FlatData([A], indices=[1, 2])
        C = B.concat(A)
        self.assertListEqual([2, 3, 1, 2, 3], list(C))

    def test_concat_flattens_same_classification_dataset(self):
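        # Same flattening check as above, but through the
        # TaskAwareClassificationDataset wrapper: repeated concatenation
        # must keep both depth and leaf count bounded.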
        D = TaskAwareClassificationDataset([[1, 2, 3]])
        B = concat_datasets([])
        B = B.concat(D)
        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2
        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2

        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2

    def test_concat_flattens_nc_scenario_dataset(self):
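        # Repeatedly concatenating the same experience dataset from a
        # benchmark stream must keep the number of leaf datasets bounded
        # instead of growing with each concat.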
        benchmark = get_fast_benchmark()
        s = benchmark.train_stream
        B = concat_datasets([s[1].dataset])
        D1 = s[0].dataset

        B1 = D1.concat(B)
        print(f"DATA depth={_flatdata_depth(B1)}, dsets={len(B1._datasets)}")
        assert len(B1._datasets) <= 2
        B2 = D1.concat(B1)
        print(f"DATA depth={_flatdata_depth(B2)}, dsets={len(B2._datasets)}")
        assert len(B2._datasets) <= 2
        B3 = D1.concat(B2)
        print(f"DATA depth={_flatdata_depth(B3)}, dsets={len(B3._datasets)}")
        assert len(B3._datasets) <= 2

    def test_concat_flattens_nc_scenario_dataset2(self):
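        # Variant of the test above that interleaves datasets taken from
        # two different experiences of the stream.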
        bm = get_fast_benchmark()
        s = bm.train_stream

        B = concat_datasets([])  # empty dataset
        D1 = s[0].dataset
        print(repr(D1))

        D2a = s[1].dataset
        D2b = s[1].dataset

        B1 = D1.concat(B)
        print(f"DATA depth={_flatdata_depth(B1)}, dsets={len(B1._datasets)}")
        print(repr(B1))
        assert len(B1._datasets) <= 2
        B2 = D2a.concat(B1)
        print(f"DATA depth={_flatdata_depth(B2)}, dsets={len(B2._datasets)}")
        print(repr(B2))
        assert len(B2._datasets) <= 2
        B3 = D2b.concat(B2)
        print(f"DATA depth={_flatdata_depth(B3)}, dsets={len(B3._datasets)}")
        print(repr(B3))
        assert len(B3._datasets) <= 2

    def test_flattening_replay_ocl(self):
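        # Online continual learning with replay: a reservoir buffer is
        # updated once per size-1 mini-experience, which chains many
        # concat/subset operations internally. The buffer's FlatData
        # should remain flat (at most two leaf datasets) instead of
        # degenerating into a deep tree.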
        benchmark = get_fast_benchmark()
        buffer = ReservoirSamplingBuffer(100)

        for t, exp in enumerate(
            FixedSizeExperienceSplitter(benchmark.train_stream[0], 1, None)
        ):
            buffer.update_from_dataset(exp.dataset)
            b = buffer.buffer
            # depths = _flatdata_depth(b)
            # lenidxs = len(b._indices)
            # lendsets = len(b._datasets)
            # print(f"DATA depth={depths}, idxs={lenidxs}, dsets={lendsets}")
            #
            # atts = [b.targets.data, b.targets_task_labels.data]
            # depths = [_flatdata_depth(b) for b in atts]
            # lenidxs = [len(b._indices) for b in atts]
            # lendsets = [len(b._datasets) for b in atts]
            # print(f"(t={t}) ATTS depth={depths}, idxs={lenidxs},
            # dsets={lendsets}")
            if t > 5:
                break
        print(f"DATA depth={_flatdata_depth(b)}, dsets={len(b._datasets)}")
        assert len(b._datasets) <= 2

        for t, exp in enumerate(
            FixedSizeExperienceSplitter(benchmark.train_stream[1], 1, None)
        ):
            buffer.update_from_dataset(exp.dataset)
            b = buffer.buffer
            # depths = _flatdata_depth(b)
            # lenidxs = len(b._indices)
            # lendsets = len(b._datasets)
            # print(f"DATA depth={depths}, idxs={lenidxs}, dsets={lendsets}")
            #
            # atts = [b.targets.data, b.targets_task_labels.data]
            # depths = [_flatdata_depth(b) for b in atts]
            # lenidxs = [len(b._indices) for b in atts]
            # lendsets = [len(b._datasets) for b in atts]
            # print(f"(t={t}) ATTS depth={depths}, idxs={lenidxs},
            # dsets={lendsets}")
            if t > 5:
                break
        print(f"DATA depth={_flatdata_depth(b)}, dsets={len(b._datasets)}")
        assert len(b._datasets) <= 2


class LazyIndicesTests(unittest.TestCase):
    def test_basic(self):
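        # LazyIndices should behave like an eager list of indices:
        # a single sequence, the concatenation of two sequences, and a
        # constant offset added to every element.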
        eager = list(range(10))
        li = LazyIndices(eager)
        self.assertListEqual(eager, list(li))
        self.assertEqual(len(eager), len(li))

        li = LazyIndices(eager, eager)
        self.assertListEqual(eager + eager, list(li))
        self.assertEqual(len(eager) * 2, len(li))

        li = LazyIndices(eager, offset=7)
        self.assertListEqual([el + 7 for el in eager], list(li))
        self.assertEqual(len(eager), len(li))

    def test_range(self):
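        # LazyRange is the lazy equivalent of range(start, end), with an
        # optional offset added to every element; offsets of nested lazy
        # sequences accumulate.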
        eager = list(range(1, 11))
        li = LazyRange(start=1, end=11)
        self.assertListEqual(eager, list(li))
        self.assertEqual(len(eager), len(li))

        eager = list(range(1, 11))
        li = LazyRange(start=0, end=10, offset=1)
        self.assertListEqual(eager, list(li))
        self.assertEqual(len(eager), len(li))

        eager = list(range(8, 18)) + list(range(12, 15))
        a = LazyRange(start=0, end=10, offset=1)
        b = LazyRange(start=2, end=5, offset=3)
        li = LazyIndices(a, b, offset=7)
        self.assertListEqual(eager, list(li))
        self.assertEqual(len(eager), len(li))

    def test_recursion(self):
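        # Chaining LazyIndices deeper than the interpreter's recursion
        # limit must not raise RecursionError, either when computing the
        # length or when iterating over the elements.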
        eager = list(range(10))

        li = LazyIndices(eager, offset=0)
        # TODO: speed up this test. Can we avoid checking such a high limit?
        limit = sys.getrecursionlimit() * 2 + 10
        for i in range(limit):
            li = LazyIndices(li, eager, offset=0)

        self.assertEqual(len(eager) * (i + 2), len(li))
        for el in li:  # keep this to check recursion error
            pass


if __name__ == "__main__":
    unittest.main()