ContinualAI / avalanche · build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200

Coverage summary:
  • 417 of 538 new or added lines in 30 files covered (77.51%)
  • 43 existing lines in 5 files now uncovered
  • 16586 of 22630 relevant lines covered (73.29%)
  • 2.93 hits per line

Source file: /tests/benchmarks/test_flat_data.py (99.39% of lines covered)
import unittest
import random

import torch

from avalanche.benchmarks import fixed_size_experience_split
from avalanche.benchmarks.utils import AvalancheDataset, concat_datasets
from avalanche.benchmarks.utils.classification_dataset import (
    ClassificationDataset,
)
from avalanche.benchmarks.utils.flat_data import (
    FlatData,
    _flatten_datasets_and_reindex,
    _flatdata_depth,
    _flatdata_print,
)
from avalanche.training import ReservoirSamplingBuffer
from tests.unit_tests_utils import get_fast_benchmark


class AvalancheDatasetTests(unittest.TestCase):
    def test_flatdata_subset_concat_stack_overflow(self):
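        # Regression test for deep subset/concat chains: after 500 rounds of
        # permuting and concatenating, the contents must match the expected
        # permutations and the internal tree must remain flat (depth 2), so
        # that indexing never recurses once per level and overflows the stack.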
        d_sz = 5
        x_raw = torch.randint(0, 7, (d_sz,))
        data = FlatData([x_raw])
        dataset_hierarchy_depth = 500

        # prepare random permutations for each step
        perms = []
        for _ in range(dataset_hierarchy_depth):
            idx_permuted = list(range(d_sz))
            random.shuffle(idx_permuted)
            perms.append(idx_permuted)

        # compute expected indices after all permutations
        current_indices = range(d_sz)
        true_indices = []
        true_indices.append(list(current_indices))
        for idx in range(dataset_hierarchy_depth):
            current_indices = [current_indices[x] for x in perms[idx]]
            true_indices.append(current_indices)
        true_indices = list(reversed(true_indices))

        # apply permutations and concatenations iteratively
        curr_dataset = data
        for idx in range(dataset_hierarchy_depth):
            # print(idx, "depth: ", _flatdata_depth(curr_dataset))
            subset = curr_dataset.subset(indices=perms[idx])
            # print("SUBSET:")
            # _flatdata_print(subset)
            curr_dataset = subset.concat(curr_dataset)
            # print("CONCAT:")
            # _flatdata_print(curr_dataset)

        self.assertEqual(
            d_sz * dataset_hierarchy_depth + d_sz, len(curr_dataset)
        )
        for idx in range(dataset_hierarchy_depth):
            leaf_range = range(idx * d_sz, (idx + 1) * d_sz)
            permuted = true_indices[idx]

            x_leaf = torch.stack(
                [curr_dataset[idx] for idx in leaf_range], dim=0
            )
            self.assertTrue(torch.equal(x_raw[permuted], x_leaf))

        slice_idxs = list(
            range(d_sz * dataset_hierarchy_depth, len(curr_dataset))
        )
        x_slice = torch.stack([curr_dataset[idx] for idx in slice_idxs], dim=0)
        self.assertTrue(torch.equal(x_raw, x_slice))

        # If you broke this test it means that dataset merging is not
        # working anymore. You are probably doing something that disables
        # merging (passing custom transforms?).
        # Good luck...
        assert _flatdata_depth(curr_dataset) == 2

    def test_merging(self):
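        # Concat with the same underlying dataset followed by a subset must
        # keep collapsing to a single leaf dataset (depth 2, one entry in
        # `_datasets`), while still returning the same items as indexing the
        # pre-subset data.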
        x = torch.randn(10)
        fdata = FlatData([x])

        dd = fdata
        for _ in range(5):
            dd = dd.concat(fdata)
            assert _flatdata_depth(dd) == 2
            assert len(dd._datasets) == 1

            idxs = list(range(len(dd)))
            random.shuffle(idxs)
            dd_old = dd
            dd = dd.subset(idxs[:12])

            for i in range(12):
                assert dd[i] == dd_old[idxs[i]]

            assert _flatdata_depth(dd) == 2
            assert len(dd._indices) == 12
            assert len(dd._datasets) == 1


class FlatteningTests(unittest.TestCase):
    def test_flatten_and_reindex(self):
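        # Passing the same dataset three times must deduplicate it: a single
        # dataset comes back, with the index list remapped to traverse it
        # three times (so max(idxs) stays within the original length).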
        bm = get_fast_benchmark()
        D1 = bm.train_stream[0].dataset
        ds, idxs = _flatten_datasets_and_reindex([D1, D1, D1], None)

        print(f"len-ds: {len(ds)}, max={max(idxs)}, min={min(idxs)}, "
              f"lens={[len(d) for d in ds]}")
        assert len(ds) == 1
        assert len(idxs) == 3 * len(D1)
        assert max(idxs) == len(D1) - 1
        assert min(idxs) == 0

    def test_concat_flattens_same_dataset(self):
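        # Repeated concatenation of the same dataset should not grow an
        # unbounded tree. The depth/size asserts are currently disabled (see
        # the commented lines below), so this test only exercises the code
        # path and prints the resulting structure for inspection.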
        D = AvalancheDataset([[1, 2, 3]])
        B = concat_datasets([])
        B = B.concat(D)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")

        for _ in range(10):
            B = D.concat(B)
            print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
            # assert _flatdata_depth(B) <= 2
            # assert len(B._datasets) <= 2

    def test_concat_flattens_same_dataset_corner_case(self):
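        # Corner cases where one operand cannot be flattened
        # (can_flatten=False). Worked example for the first block: A picks
        # indices [1, 2] of [1, 2, 3], i.e. [2, 3]; B merely wraps A; so
        # C = A.concat(B) enumerates as [2, 3, 2, 3], and C[3] must resolve
        # through the nested structure without raising.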
        base_dataset = [1, 2, 3]
        A = FlatData([base_dataset], can_flatten=False, indices=[1, 2])
        B = FlatData([A])
        C = A.concat(B)
        C[3]
        self.assertListEqual([2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False)
        B = FlatData([A], indices=[1, 2])
        C = A.concat(B)
        self.assertListEqual([1, 2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False, indices=[1, 2])
        B = FlatData([A])
        C = B.concat(A)
        self.assertListEqual([2, 3, 2, 3], list(C))

        A = FlatData([base_dataset], can_flatten=False)
        B = FlatData([A], indices=[1, 2])
        C = B.concat(A)
        self.assertListEqual([2, 3, 1, 2, 3], list(C))

    def test_concat_flattens_same_avalanche_dataset(self):
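        # Same check through the AvalancheDataset wrapper: after each extra
        # concat of D, the structure must stay at depth <= 2 with at most
        # two leaf datasets.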
        D = AvalancheDataset([[1, 2, 3]])
        B = concat_datasets([])
        B = B.concat(D)
        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2
        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2

        B = D.concat(B)
        print(f"DATA depth={_flatdata_depth(B)}, dsets={len(B._datasets)}")
        assert _flatdata_depth(B) <= 2
        assert len(B._datasets) <= 2

    def test_concat_flattens_nc_scenario_dataset(self):
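        # Flattening must also hold for benchmark-generated datasets:
        # repeatedly prepending the first experience's dataset keeps the
        # number of leaf datasets bounded (<= 2).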
        benchmark = get_fast_benchmark()
        s = benchmark.train_stream
        B = concat_datasets([s[1].dataset])
        D1 = s[0].dataset

        B1 = D1.concat(B)
        print(f"DATA depth={_flatdata_depth(B1)}, dsets={len(B1._datasets)}")
        assert len(B1._datasets) <= 2
        B2 = D1.concat(B1)
        print(f"DATA depth={_flatdata_depth(B2)}, dsets={len(B2._datasets)}")
        assert len(B2._datasets) <= 2
        B3 = D1.concat(B2)
        print(f"DATA depth={_flatdata_depth(B3)}, dsets={len(B3._datasets)}")
        assert len(B3._datasets) <= 2

    def test_concat_flattens_nc_scenario_dataset2(self):
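        # As above, but starting from an empty concat and mixing datasets
        # taken from different experiences; the leaf count must stay <= 2
        # even when the concatenated datasets are not all identical.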
        bm = get_fast_benchmark()
        s = bm.train_stream

        B = concat_datasets([])  # empty dataset
        D1 = s[0].dataset
        print(repr(D1))

        D2a = s[1].dataset
        D2b = s[1].dataset

        B1 = D1.concat(B)
        print(f"DATA depth={_flatdata_depth(B1)}, dsets={len(B1._datasets)}")
        print(repr(B1))
        assert len(B1._datasets) <= 2
        B2 = D2a.concat(B1)
        print(f"DATA depth={_flatdata_depth(B2)}, dsets={len(B2._datasets)}")
        print(repr(B2))
        assert len(B2._datasets) <= 2
        B3 = D2b.concat(B2)
        print(f"DATA depth={_flatdata_depth(B3)}, dsets={len(B3._datasets)}")
        print(repr(B3))
        assert len(B3._datasets) <= 2

    def test_flattening_replay_ocl(self):
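        # Online continual-learning use case: stream experiences one sample
        # at a time into a ReservoirSamplingBuffer and check that the
        # buffer's dataset does not accumulate nesting across updates
        # (leaf count stays <= 2).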
        benchmark = get_fast_benchmark()
        buffer = ReservoirSamplingBuffer(100)

        for t, exp in enumerate(fixed_size_experience_split(
                benchmark.train_stream[0], 1, None)):
            buffer.update_from_dataset(exp.dataset)
            b = buffer.buffer
            # depths = _flatdata_depth(b)
            # lenidxs = len(b._indices)
            # lendsets = len(b._datasets)
            # print(f"DATA depth={depths}, idxs={lenidxs}, dsets={lendsets}")
            #
            # atts = [b.targets.data, b.targets_task_labels.data]
            # depths = [_flatdata_depth(b) for b in atts]
            # lenidxs = [len(b._indices) for b in atts]
            # lendsets = [len(b._datasets) for b in atts]
            # print(f"(t={t}) ATTS depth={depths}, idxs={lenidxs}, "
            #       f"dsets={lendsets}")
            if t > 5:
                break
        print(f"DATA depth={_flatdata_depth(b)}, dsets={len(b._datasets)}")
        assert len(b._datasets) <= 2

        for t, exp in enumerate(fixed_size_experience_split(
                benchmark.train_stream[1], 1, None)):
            buffer.update_from_dataset(exp.dataset)
            b = buffer.buffer
            # (same optional debug prints as in the loop above)
            if t > 5:
                break
        print(f"DATA depth={_flatdata_depth(b)}, dsets={len(b._datasets)}")
        assert len(b._datasets) <= 2


if __name__ == "__main__":
    unittest.main()