• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 8098020118

29 Feb 2024 02:57PM UTC coverage: 51.806% (-12.4%) from 64.17%
8098020118

push

github

web-flow
Update test-coverage-coveralls.yml

14756 of 28483 relevant lines covered (51.81%)

0.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

17.18
/tests/benchmarks/utils/test_transformations.py
1
import copy
1✔
2
import unittest
1✔
3
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
1✔
4
from avalanche.benchmarks.utils.data import AvalancheDataset
1✔
5
from avalanche.benchmarks.utils.dataset_traversal_utils import single_flat_dataset
1✔
6
from avalanche.benchmarks.utils.detection_dataset import DetectionDataset
1✔
7
from avalanche.benchmarks.classic.cmnist import SplitMNIST
1✔
8
from avalanche.benchmarks.utils.transform_groups import TransformGroups
1✔
9

10
from avalanche.benchmarks.utils.transforms import (
1✔
11
    MultiParamCompose,
12
    MultiParamTransformCallable,
13
    TupleTransform,
14
    flat_transforms_recursive,
15
)
16

17
import torch
1✔
18
from PIL import ImageChops
1✔
19
from torch import Tensor
1✔
20
from torch.utils.data import DataLoader, ConcatDataset
1✔
21
from torchvision.datasets import MNIST
1✔
22
from torchvision.transforms import (
1✔
23
    ToTensor,
24
    Compose,
25
    CenterCrop,
26
    Normalize,
27
    Lambda,
28
    RandomHorizontalFlip,
29
)
30
from torchvision.transforms.functional import to_tensor
1✔
31
from PIL.Image import Image
1✔
32

33
from tests.unit_tests_utils import get_fast_detection_datasets
1✔
34

35

36
def pil_images_equal(img_a, img_b):
    """Return True when the two PIL images have identical pixel content.

    ``ImageChops.difference`` produces a pixel-wise absolute difference;
    for identical images that diff is all-zero, and ``getbbox()`` on an
    all-zero image returns ``None`` (no non-zero region to bound).
    """
    return ImageChops.difference(img_a, img_b).getbbox() is None
40

41

42
def zero_if_label_2(img_tensor: Tensor, class_label):
    """Blank out ``img_tensor`` (in place) when ``class_label`` equals 2.

    Returns the same tensor object together with the unchanged label, so the
    result can be used directly as an ``(x, y)`` transformation output.
    """
    # In-place fill: callers that still hold a reference to the original
    # tensor observe the zeroing as well.
    if int(class_label) == 2:
        img_tensor.fill_(0.0)
    return img_tensor, class_label
47

48

49
def get_mbatch(data, batch_size=5):
    """Return the first mini-batch of ``data``.

    The dataset's own ``collate_fn`` attribute is used so that datasets with
    custom batching (e.g. detection datasets) are collated correctly.
    ``shuffle=False`` keeps the result deterministic.
    """
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data.collate_fn,
    )
    return next(iter(loader))
54

55

56
class TransformsTest(unittest.TestCase):
    """Unit tests for Avalanche's multi-parameter / tuple transformation
    wrappers and for the transformation-flattening utilities.

    The detection-based tests use :func:`get_fast_detection_datasets`, which
    yields ``(image, target_dict, task_label)`` samples where the target dict
    contains (at least) a ``"boxes"`` tensor.
    """

    def test_multi_param_transform_callable(self):
        """A ``MultiParamTransformCallable`` must forward both ``x`` and ``y``
        to the wrapped callable and support ``__eq__`` between two wrappers
        built from the same callable."""
        dataset: DetectionDataset
        dataset, _ = get_fast_detection_datasets()

        # Advance until a sample that actually has bounding boxes is found,
        # then rewind ``i`` to the index of that sample.
        boxes = []
        i = 0
        while len(boxes) == 0:
            x_orig, y_orig, t_orig = dataset[i]
            boxes = y_orig["boxes"]
            i += 1
        i -= 1

        # Build the expected outputs by applying the same "+1" edits that
        # ``do_something_xy`` performs, on deep copies of the raw sample.
        x_expect = to_tensor(copy.deepcopy(x_orig))
        x_expect[0][0] += 1

        y_expect = copy.deepcopy(y_orig)
        y_expect["boxes"][0][0] += 1

        def do_something_xy(img, target):
            # Transform that touches both the image and the target dict.
            img = to_tensor(img)
            img[0][0] += 1
            target["boxes"][0][0] += 1
            return img, target

        uut = MultiParamTransformCallable(do_something_xy)

        # Test __eq__
        uut_eq = MultiParamTransformCallable(do_something_xy)
        self.assertTrue(uut == uut_eq)
        self.assertTrue(uut_eq == uut)

        x, y, t = uut(*dataset[i])

        self.assertIsInstance(x, torch.Tensor)
        self.assertIsInstance(y, dict)
        self.assertIsInstance(t, int)

        self.assertTrue(torch.equal(x_expect, x))
        keys = set(y_expect.keys())
        self.assertSetEqual(keys, set(y.keys()))

        # Every tensor field of the target must match the expected copy.
        for k in keys:
            self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}")

    def test_multi_param_compose(self):
        """``MultiParamCompose`` must chain transforms with differing
        arities, warn when arities disagree, and support ``__eq__``."""
        dataset: DetectionDataset
        dataset, _ = get_fast_detection_datasets()

        # Counter incremented by the (x, y) transform so the test can verify
        # it was invoked exactly once per composed call.
        assert_called = 0

        def do_something_xy(img: Tensor, target):
            nonlocal assert_called
            assert_called += 1
            img = img.clone()
            img[0][0] += 1
            target["boxes"][0][0] += 1
            return img, target

        t_x = lambda x, y: (to_tensor(x), y)
        t_xy = do_something_xy
        t_x_1_element = ToTensor()

        # Advance until a sample that actually has bounding boxes is found,
        # then rewind ``i`` to the index of that sample.
        boxes = []
        i = 0
        while len(boxes) == 0:
            x_orig, y_orig, t_orig = dataset[i]
            boxes = y_orig["boxes"]
            i += 1
        i -= 1

        x_expect = to_tensor(copy.deepcopy(x_orig))
        x_expect[0][0] += 1

        y_expect = copy.deepcopy(y_orig)
        y_expect["boxes"][0][0] += 1

        uut_2 = MultiParamCompose([t_x, t_xy])

        # Test __eq__
        uut_2_eq = MultiParamCompose([t_x, t_xy])
        self.assertTrue(uut_2 == uut_2_eq)
        self.assertTrue(uut_2_eq == uut_2)

        with self.assertWarns(Warning):
            # Assert that the following warn is raised:
            # "Transformations define a different number of parameters. ..."
            uut_1 = MultiParamCompose([t_x_1_element, t_xy])

        # Both compositions must produce the same final result.
        for uut, uut_type in zip((uut_1, uut_2), ("uut_1", "uut_2")):
            with self.subTest(uut_type=uut_type):
                initial_assert_called = assert_called

                x, y, t = uut(*dataset[i])

                # The (x, y) transform must have run exactly once.
                self.assertEqual(initial_assert_called + 1, assert_called)

                self.assertIsInstance(x, torch.Tensor)
                self.assertIsInstance(y, dict)
                self.assertIsInstance(t, int)

                self.assertTrue(torch.equal(x_expect, x))
                keys = set(y_expect.keys())
                self.assertSetEqual(keys, set(y.keys()))

                for k in keys:
                    self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}")

    def test_tuple_transform(self):
        """``TupleTransform`` must apply one transform per element, accept
        ``None`` placeholders, support ``__eq__``, and fail when the number
        of transforms does not match the number of elements."""
        dataset = MNIST(root=default_dataset_location("mnist"), download=True)

        t_x = ToTensor()
        t_y = lambda element: element + 1
        t_bad = lambda element: element - 1

        uut = TupleTransform([t_x, t_y])

        uut_eq = TupleTransform(
            (t_x, t_y)  # Also test with a tuple instead of a list here
        )

        # ``None`` means "leave this element untouched".
        uut_not_x = TupleTransform([None, t_y])

        uut_bad = TupleTransform((t_x, t_y, t_bad))

        x_orig, y_orig = dataset[0]

        # Test with x transform
        x, y = uut(*dataset[0])

        self.assertIsInstance(x, torch.Tensor)
        self.assertIsInstance(y, int)

        self.assertTrue(torch.equal(to_tensor(x_orig), x))
        self.assertEqual(y_orig + 1, y)

        # Test without x transform
        x, y = uut_not_x(*dataset[0])

        self.assertIsInstance(x, Image)
        self.assertIsInstance(y, int)

        self.assertEqual(x_orig, x)
        self.assertEqual(y_orig + 1, y)

        # Check __eq__ works
        self.assertTrue(uut == uut_eq)
        self.assertTrue(uut_eq == uut)

        self.assertFalse(uut == uut_not_x)
        self.assertFalse(uut_not_x == uut)

        with self.assertRaises(Exception):
            # uut_bad has 3 transforms, which is incorrect
            uut_bad(*dataset[0])

    def test_flat_transforms_recursive_only_torchvision(self):
        """``flat_transforms_recursive`` must flatten plain torchvision
        transforms, lists, and ``Compose`` objects (including mixes of the
        three) into a flat list."""
        x_transform = ToTensor()
        x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)]
        x_transform_composed = Compose(x_transform_list)

        expected_x = [x_transform] + x_transform_list

        # Single transforms checks
        self.assertSequenceEqual(
            [x_transform], flat_transforms_recursive([x_transform], 0)
        )

        self.assertSequenceEqual(
            [x_transform], flat_transforms_recursive(x_transform, 0)
        )

        self.assertSequenceEqual(
            x_transform_list, flat_transforms_recursive(x_transform_list, 0)
        )

        # A Compose must be unwrapped to its inner transform list.
        self.assertSequenceEqual(
            x_transform_list, flat_transforms_recursive(x_transform_composed, 0)
        )

        # Hybrid list checks
        self.assertSequenceEqual(
            expected_x,
            flat_transforms_recursive([x_transform, x_transform_composed], 0),
        )

    def test_flat_transforms_recursive_from_dataset(self):
        """Flattening must also traverse nested ``AvalancheDataset`` wrappers
        (subsets, extra transform groups, and both Avalanche and PyTorch
        concatenations), yielding the transforms in application order."""
        x_transform = ToTensor()
        x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)]
        x_transform_additional = RandomHorizontalFlip(p=0.2)
        x_transform_composed = Compose(x_transform_list)

        # Expected x pipeline: the MNIST-level transform, then the "train"
        # group's composed transforms, then the outer wrapper's transform.
        expected_x = [x_transform] + x_transform_list + [x_transform_additional]

        y_transform = Lambda(lambda x: max(0, x - 1))

        dataset = MNIST(
            root=default_dataset_location("mnist"), download=True, transform=x_transform
        )

        transform_group = TransformGroups(
            transform_groups={
                "train": TupleTransform([x_transform_composed, y_transform])
            }
        )

        # Two equal-but-distinct groups: one per concatenation branch.
        transform_group_additional_1a = TransformGroups(
            transform_groups={"train": TupleTransform([x_transform_additional, None])}
        )
        transform_group_additional_1b = TransformGroups(
            transform_groups={"train": TupleTransform([x_transform_additional, None])}
        )

        avl_dataset = AvalancheDataset([dataset], transform_groups=transform_group)

        avl_subset_1 = avl_dataset.subset([1, 2, 3])
        avl_subset_2 = avl_dataset.subset([5, 6, 7])

        # Wrap each subset so an additional transform group sits on top.
        avl_subset_1 = AvalancheDataset(
            [avl_subset_1], transform_groups=transform_group_additional_1a
        )
        avl_subset_2 = AvalancheDataset(
            [avl_subset_2], transform_groups=transform_group_additional_1b
        )

        # Both concatenation flavors must flatten to the same pipelines.
        for concat_type, avl_concat in zip(
            ["avalanche", "pytorch"],
            [
                avl_subset_1.concat(avl_subset_2),
                ConcatDataset([avl_subset_1, avl_subset_2]),
            ],
        ):
            with self.subTest("Concatenation type", concat_type=concat_type):
                _, _, transforms = single_flat_dataset(avl_concat)
                # Position 0 = x pipeline, position 1 = y pipeline.
                x_flattened = flat_transforms_recursive(transforms, 0)
                y_flattened = flat_transforms_recursive(transforms, 1)

                self.assertSequenceEqual(expected_x, x_flattened)
                self.assertSequenceEqual([y_transform], y_flattened)
295

296

297
# Allow running this test module directly (``python test_transformations.py``).
if __name__ == "__main__":
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc