• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5725326611

pending completion
5725326611

push

github

web-flow
Merge pull request #1439 from lrzpellegrini/ffcv_support_pt2

FFCV support

500 of 806 new or added lines in 14 files covered. (62.03%)

1 existing line in 1 file now uncovered.

17477 of 23989 relevant lines covered (72.85%)

2.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.09
/tests/test_transformations.py
1
import copy
4✔
2
import unittest
4✔
3
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
4✔
4
from avalanche.benchmarks.utils.data import AvalancheDataset
4✔
5
from avalanche.benchmarks.utils.dataset_traversal_utils import single_flat_dataset
4✔
6
from avalanche.benchmarks.utils.detection_dataset import DetectionDataset
4✔
7
from avalanche.benchmarks.classic.cmnist import SplitMNIST
4✔
8
from avalanche.benchmarks.utils.transform_groups import TransformGroups
4✔
9

10
from avalanche.benchmarks.utils.transforms import (
4✔
11
    MultiParamCompose,
12
    MultiParamTransformCallable,
13
    TupleTransform,
14
    flat_transforms_recursive,
15
)
16

17
import torch
4✔
18
from PIL import ImageChops
4✔
19
from torch import Tensor
4✔
20
from torch.utils.data import DataLoader, ConcatDataset
4✔
21
from torchvision.datasets import MNIST
4✔
22
from torchvision.transforms import (
4✔
23
    ToTensor,
24
    Compose,
25
    CenterCrop,
26
    Normalize,
27
    Lambda,
28
    RandomHorizontalFlip,
29
)
30
from torchvision.transforms.functional import to_tensor
4✔
31
from PIL.Image import Image
4✔
32

33
from tests.unit_tests_utils import get_fast_detection_datasets
4✔
34

35

36
def pil_images_equal(img_a, img_b):
    """Return True if the two PIL images have identical pixel content."""
    # ImageChops.difference returns a pixel-wise absolute difference;
    # getbbox() on it is None exactly when every pixel matches.
    return ImageChops.difference(img_a, img_b).getbbox() is None
40

41

42
def zero_if_label_2(img_tensor: Tensor, class_label):
    """Zero out ``img_tensor`` in-place when the class label equals 2.

    Used as a multi-argument (x, y) transform in the tests below: images of
    class 2 are blanked to all zeros, other samples pass through untouched.

    :param img_tensor: the image tensor; modified in-place for class 2.
    :param class_label: the class label; anything convertible to ``int``.
    :return: the ``(img_tensor, class_label)`` pair.
    """
    if int(class_label) == 2:
        # In-place zero fill. Equivalent to the previous
        # torch.full(img_tensor.shape, 0.0, out=img_tensor) but idiomatic
        # and correct for any dtype, not just floating point.
        img_tensor.zero_()

    return img_tensor, class_label
47

48

49
def get_mbatch(data, batch_size=5):
    """Return the first mini-batch of ``data``.

    The dataset is iterated without shuffling and collated with the
    dataset's own ``collate_fn``.

    :param data: a dataset exposing a ``collate_fn`` attribute.
    :param batch_size: number of samples in the returned batch.
    :return: the first collated mini-batch.
    """
    loader = DataLoader(
        data,
        shuffle=False,
        batch_size=batch_size,
        collate_fn=data.collate_fn,
    )
    return next(iter(loader))
54

55

56
class TransformsTest(unittest.TestCase):
    """Tests for the transform wrappers in ``avalanche.benchmarks.utils.transforms``.

    Covers ``MultiParamTransformCallable``, ``MultiParamCompose``,
    ``TupleTransform`` and the ``flat_transforms_recursive`` utility.
    """

    def test_multi_param_transform_callable(self):
        """A callable taking (x, y) is applied to both detection fields."""
        dataset: DetectionDataset
        dataset, _ = get_fast_detection_datasets()

        # Find the first sample that actually contains bounding boxes,
        # so the transform below has something to modify.
        boxes = []
        i = 0
        while len(boxes) == 0:
            x_orig, y_orig, t_orig = dataset[i]
            boxes = y_orig["boxes"]
            i += 1
        i -= 1

        # Expected outputs: image and first box both shifted by +1.
        x_expect = to_tensor(copy.deepcopy(x_orig))
        x_expect[0][0] += 1

        y_expect = copy.deepcopy(y_orig)
        y_expect["boxes"][0][0] += 1

        def do_something_xy(img, target):
            # Multi-parameter transform: modifies both image and target.
            img = to_tensor(img)
            img[0][0] += 1
            target["boxes"][0][0] += 1
            return img, target

        uut = MultiParamTransformCallable(do_something_xy)

        # Test __eq__
        uut_eq = MultiParamTransformCallable(do_something_xy)
        self.assertTrue(uut == uut_eq)
        self.assertTrue(uut_eq == uut)

        x, y, t = uut(*dataset[i])

        self.assertIsInstance(x, torch.Tensor)
        self.assertIsInstance(y, dict)
        self.assertIsInstance(t, int)

        self.assertTrue(torch.equal(x_expect, x))
        keys = set(y_expect.keys())
        self.assertSetEqual(keys, set(y.keys()))

        for k in keys:
            self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}")

    def test_multi_param_compose(self):
        """Composition of single- and multi-parameter transforms."""
        dataset: DetectionDataset
        dataset, _ = get_fast_detection_datasets()

        assert_called = 0

        def do_something_xy(img: Tensor, target):
            # Counts invocations so the test can verify the composed
            # pipeline calls this transform exactly once per sample.
            nonlocal assert_called
            assert_called += 1
            img = img.clone()
            img[0][0] += 1
            target["boxes"][0][0] += 1
            return img, target

        t_x = lambda x, y: (to_tensor(x), y)
        t_xy = do_something_xy
        t_x_1_element = ToTensor()

        # Find the first sample with at least one bounding box.
        boxes = []
        i = 0
        while len(boxes) == 0:
            x_orig, y_orig, t_orig = dataset[i]
            boxes = y_orig["boxes"]
            i += 1
        i -= 1

        x_expect = to_tensor(copy.deepcopy(x_orig))
        x_expect[0][0] += 1

        y_expect = copy.deepcopy(y_orig)
        y_expect["boxes"][0][0] += 1

        uut_2 = MultiParamCompose([t_x, t_xy])

        # Test __eq__
        uut_2_eq = MultiParamCompose([t_x, t_xy])
        self.assertTrue(uut_2 == uut_2_eq)
        self.assertTrue(uut_2_eq == uut_2)

        with self.assertWarns(Warning):
            # Assert that the following warn is raised:
            # "Transformations define a different number of parameters. ..."
            uut_1 = MultiParamCompose([t_x_1_element, t_xy])

        # Both compositions must produce the same (transformed) sample.
        for uut, uut_type in zip((uut_1, uut_2), ("uut_1", "uut_2")):
            with self.subTest(uut_type=uut_type):
                initial_assert_called = assert_called

                x, y, t = uut(*dataset[i])

                self.assertEqual(initial_assert_called + 1, assert_called)

                self.assertIsInstance(x, torch.Tensor)
                self.assertIsInstance(y, dict)
                self.assertIsInstance(t, int)

                self.assertTrue(torch.equal(x_expect, x))
                keys = set(y_expect.keys())
                self.assertSetEqual(keys, set(y.keys()))

                for k in keys:
                    self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}")

    def test_tuple_transform(self):
        """TupleTransform applies one transform per element (None = no-op)."""
        dataset = MNIST(root=default_dataset_location("mnist"), download=True)

        t_x = ToTensor()
        t_y = lambda element: element + 1
        t_bad = lambda element: element - 1

        uut = TupleTransform([t_x, t_y])

        uut_eq = TupleTransform(
            (t_x, t_y)  # Also test with a tuple instead of a list here
        )

        uut_not_x = TupleTransform([None, t_y])

        uut_bad = TupleTransform((t_x, t_y, t_bad))

        x_orig, y_orig = dataset[0]

        # Test with x transform
        x, y = uut(*dataset[0])

        self.assertIsInstance(x, torch.Tensor)
        self.assertIsInstance(y, int)

        self.assertTrue(torch.equal(to_tensor(x_orig), x))
        self.assertEqual(y_orig + 1, y)

        # Test without x transform
        x, y = uut_not_x(*dataset[0])

        self.assertIsInstance(x, Image)
        self.assertIsInstance(y, int)

        self.assertEqual(x_orig, x)
        self.assertEqual(y_orig + 1, y)

        # Check __eq__ works
        self.assertTrue(uut == uut_eq)
        self.assertTrue(uut_eq == uut)

        self.assertFalse(uut == uut_not_x)
        self.assertFalse(uut_not_x == uut)

        with self.assertRaises(Exception):
            # uut_bad has 3 transforms, which is incorrect
            uut_bad(*dataset[0])

    def test_flat_transforms_recursive_only_torchvision(self):
        """Flattening plain torchvision transforms (single, list, Compose)."""
        x_transform = ToTensor()
        x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)]
        x_transform_composed = Compose(x_transform_list)

        expected_x = [x_transform] + x_transform_list

        # Single transforms checks
        self.assertSequenceEqual(
            [x_transform], flat_transforms_recursive([x_transform], 0)
        )

        self.assertSequenceEqual(
            [x_transform], flat_transforms_recursive(x_transform, 0)
        )

        self.assertSequenceEqual(
            x_transform_list, flat_transforms_recursive(x_transform_list, 0)
        )

        self.assertSequenceEqual(
            x_transform_list, flat_transforms_recursive(x_transform_composed, 0)
        )

        # Hybrid list checks
        self.assertSequenceEqual(
            expected_x,
            flat_transforms_recursive([x_transform, x_transform_composed], 0),
        )

    def test_flat_transforms_recursive_from_dataset(self):
        """Flattening transforms gathered from nested/concatenated datasets."""
        x_transform = ToTensor()
        x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)]
        x_transform_additional = RandomHorizontalFlip(p=0.2)
        x_transform_composed = Compose(x_transform_list)

        # Flattening must preserve application order:
        # dataset-level transform, then group transforms, then the
        # transform added by the outer wrapper dataset.
        expected_x = [x_transform] + x_transform_list + [x_transform_additional]

        y_transform = Lambda(lambda x: max(0, x - 1))

        dataset = MNIST(
            root=default_dataset_location("mnist"), download=True, transform=x_transform
        )

        transform_group = TransformGroups(
            transform_groups={
                "train": TupleTransform([x_transform_composed, y_transform])
            }
        )

        transform_group_additional_1a = TransformGroups(
            transform_groups={"train": TupleTransform([x_transform_additional, None])}
        )
        transform_group_additional_1b = TransformGroups(
            transform_groups={"train": TupleTransform([x_transform_additional, None])}
        )

        avl_dataset = AvalancheDataset([dataset], transform_groups=transform_group)

        avl_subset_1 = avl_dataset.subset([1, 2, 3])
        avl_subset_2 = avl_dataset.subset([5, 6, 7])

        avl_subset_1 = AvalancheDataset(
            [avl_subset_1], transform_groups=transform_group_additional_1a
        )
        avl_subset_2 = AvalancheDataset(
            [avl_subset_2], transform_groups=transform_group_additional_1b
        )

        # Both Avalanche-native concat and plain PyTorch ConcatDataset
        # should expose the same flattened transform chains.
        for concat_type, avl_concat in zip(
            ["avalanche", "pytorch"],
            [
                avl_subset_1.concat(avl_subset_2),
                ConcatDataset([avl_subset_1, avl_subset_2]),
            ],
        ):
            with self.subTest("Concatenation type", concat_type=concat_type):
                _, _, transforms = single_flat_dataset(avl_concat)
                x_flattened = flat_transforms_recursive(transforms, 0)
                y_flattened = flat_transforms_recursive(transforms, 1)

                self.assertSequenceEqual(expected_x, x_flattened)
                self.assertSequenceEqual([y_transform], y_flattened)

296

297
# Allow running this test module directly, e.g.
# ``python tests/test_transformations.py``.
if __name__ == "__main__":
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc