Coveralls coverage report — ContinualAI / avalanche, build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
Merge 60d244754 into e91562200 (github / web-flow)

417 of 538 new or added lines in 30 files covered (77.51%).
43 existing lines in 5 files are now uncovered.
16586 of 22630 relevant lines covered (73.29%), with an average of 2.93 hits per line.

Source file: /tests/distributed/test_distributed_helper.py (12.83% covered)

In the per-line view, the module-level statements and skipIf decorators are hit 4 times each and the test method definitions 3 times, while the statements inside setUp and the test bodies are uncovered in this run; the @unittest.skipIf(check_skip_distributed_test(), ...) guards suggest the distributed tests were skipped.
import os
import random
import shutil
import tempfile
import time
import unittest
import numpy as np

import torch
import torch.distributed as dst
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from avalanche.benchmarks.generators.benchmark_generators import (
    dataset_classification_benchmark,
)
from avalanche.benchmarks.utils.classification_dataset import \
    make_tensor_classification_dataset

from avalanche.distributed import DistributedHelper
from avalanche.distributed.distributed_helper import \
    RollingSeedContext, BroadcastSeedContext
from avalanche.models import SimpleMLP, as_multitask
from avalanche.models.utils import avalanche_model_adaptation

from avalanche.training.determinism.rng_manager import RNGManager
from tests.distributed.distributed_test_utils import \
    check_skip_distributed_slow_test, check_skip_distributed_test, \
    suppress_dst_tests_output, common_dst_tests_setup


class DistributedHelperTests(unittest.TestCase):

    def setUp(self) -> None:
        self.use_gpu_in_tests = common_dst_tests_setup()

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_device_id(self):
        if self.use_gpu_in_tests:
            self.assertEqual(dst.get_rank(), DistributedHelper.get_device_id())
            self.assertEqual(torch.device(f'cuda:{dst.get_rank()}'),
                             DistributedHelper.make_device())
        else:
            self.assertEqual(-1, DistributedHelper.get_device_id())
            self.assertEqual(torch.device('cpu'),
                             DistributedHelper.make_device())

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_wrap_model(self):
        mb_size = 1*2*2*3*5
        num_classes = 11
        torch.manual_seed(1234 + DistributedHelper.rank)
        mb_x = torch.randn((mb_size, 32))
        mb_y = torch.randint(0, num_classes, (mb_size,))
        mb_t = torch.full((mb_size,), 1)
        model = SimpleMLP(num_classes=num_classes, input_size=32)
        model = as_multitask(model, 'classifier')
        self.assertIsInstance(model, Module)

        device = DistributedHelper.make_device()

        if device.type == 'cuda':
            # Additional test: must raise an error if the model
            # is not already in the correct device
            with self.assertRaises(Exception):
                model_wrapped = DistributedHelper.wrap_model(model)

        model = model.to(device)

        model_wrapped = DistributedHelper.wrap_model(model)
        self.assertIsInstance(model_wrapped, DistributedDataParallel)
        self.assertNotIsInstance(model, DistributedDataParallel)

        device = DistributedHelper.make_device()
        mb_x = mb_x.to(device)
        mb_y = mb_y.to(device)
        mb_t = mb_t.to(device)
        model = model.to(device)

        model.eval()
        model_wrapped.eval()

        benchmark = dataset_classification_benchmark(
            [make_tensor_classification_dataset(
                mb_x, mb_y, mb_t, task_labels=mb_t.tolist()
            )],
            [make_tensor_classification_dataset(
                mb_x, mb_y, mb_t, task_labels=mb_t.tolist()
            )]
        )

        avalanche_model_adaptation(model, benchmark.train_stream[0])

        with torch.no_grad():
            mb_out1 = model(mb_x, mb_t).detach()
            self.assertEqual(mb_out1.device, device)
            self.assertSequenceEqual([mb_size, num_classes], mb_out1.shape)

            mb_out2 = model_wrapped(mb_x, mb_t).detach()
            self.assertEqual(mb_out2.device, device)
            self.assertSequenceEqual([mb_size, num_classes], mb_out2.shape)

            self.assertTrue(torch.equal(mb_out1, mb_out2))

            mb_out_all = DistributedHelper.cat_all(mb_out2)

            start_idx = mb_size * DistributedHelper.rank
            end_idx = start_idx + mb_size

            self.assertTrue(torch.equal(mb_out1,
                                        mb_out_all[start_idx: end_idx]))

        self.assertTrue(model is DistributedHelper.unwrap_model(model_wrapped))

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_broadcast_tensor_or_objects(self):
        ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long)
        DistributedHelper.broadcast(ts)
        self.assertTrue(torch.equal(ts, torch.zeros((10,), dtype=torch.long)))

        device = DistributedHelper.make_device()
        ts = ts.to(device)

        my_object = {'a': DistributedHelper.rank, 'b': ts}
        my_object_from_main = DistributedHelper.broadcast_object(my_object)

        expect = {
            'a': 0,
            'b': torch.full((10,), 0, dtype=torch.long).tolist()}

        self.assertEqual(device, my_object_from_main['b'].device)
        my_object_from_main['b'] = my_object_from_main['b'].tolist()
        self.assertEqual(expect, my_object_from_main)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_gather_all_objects(self):
        ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long)

        device = DistributedHelper.make_device()
        ts = ts.to(device)

        my_object = {'a': DistributedHelper.rank, 'b': ts}
        all_objects = DistributedHelper.gather_all_objects(my_object)
        self.assertIsInstance(all_objects, list)
        self.assertEqual(DistributedHelper.world_size, len(all_objects))

        for rank in range(DistributedHelper.world_size):
            expect = {
                'a': rank,
                'b': torch.full((10,), rank, dtype=torch.long).tolist()}

            self.assertEqual(device, all_objects[rank]['b'].device)
            all_objects[rank]['b'] = all_objects[rank]['b'].tolist()
            self.assertEqual(expect, all_objects[rank])

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_cat_all(self):
        if DistributedHelper.rank == 0:
            ts = torch.full((10+1, 5), DistributedHelper.rank, dtype=torch.long)
        else:
            ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long)
        device = DistributedHelper.make_device()

        if device.type == 'cuda':
            # Additional test: tensors do not need to be on the default device
            DistributedHelper.cat_all(ts)

        ts = ts.to(device)

        concatenated_tensor = DistributedHelper.cat_all(ts)

        self.assertEqual(device, concatenated_tensor.device)

        expect = torch.empty((DistributedHelper.world_size * 10 + 1, 5),
                             dtype=torch.long).to(device)
        for rank in range(DistributedHelper.world_size):
            if rank == 0:
                expect[rank * 10: (rank + 1) * 10 + 1] = rank
            else:
                expect[1 + rank * 10: 1 + (rank + 1) * 10] = rank

        self.assertTrue(torch.equal(concatenated_tensor, expect))

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_gather_all_same_size(self):
        ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long)
        device = DistributedHelper.make_device()

        if device.type == 'cuda':
            # Additional test: tensors do not need to be on the default device
            DistributedHelper.gather_all(ts)

            # On the other hand, PyTorch all_gather requires tensors to be on
            # the default device
            with self.assertRaises(Exception):
                out_t = [torch.empty_like(ts)
                         for _ in range(DistributedHelper.world_size)]
                torch.distributed.all_gather(out_t, ts)

            # ... while this should work
            out_t = [torch.empty_like(ts).to(device)
                     for _ in range(DistributedHelper.world_size)]
            torch.distributed.all_gather(out_t, ts.to(device))

        ts = ts.to(device)

        for same_shape in [False, True]:
            print(f'same_shape={same_shape}')
            # with self.subTest(same_shape=same_shape):
            tensor_list = DistributedHelper.gather_all(
                ts, same_shape=same_shape)

            self.assertEqual(DistributedHelper.world_size, len(tensor_list))

            for t in tensor_list:
                self.assertEqual(device, t.device)

            for rank in range(DistributedHelper.world_size):
                expect = torch.full((10, 5), rank, dtype=torch.long).to(device)
                self.assertTrue(torch.equal(tensor_list[rank], expect))

    @unittest.skipIf(check_skip_distributed_slow_test(),
                     'Distributed tests ignored')
    def test_gather_all_performance_known_same_shape(self):
        ts = torch.full((128, 224, 224, 3),
                        DistributedHelper.rank,
                        dtype=torch.float32)
        device = DistributedHelper.make_device()
        ts = ts.to(device)

        resulting_tensors = [torch.empty_like(ts).to(device)
                             for _ in range(DistributedHelper.world_size)]

        from tqdm import tqdm
        n_times = 30
        torch.distributed.all_gather(resulting_tensors, ts)
        start_time = time.time()
        for _ in tqdm(range(n_times)):
            torch.distributed.all_gather(resulting_tensors, ts)
        end_time = time.time()
        print('Time taken by PyTorch all_gather', end_time-start_time,
              'avg', (end_time-start_time) / n_times)

        start_time = time.time()
        out_list = [None for _ in range(DistributedHelper.world_size)]
        torch.distributed.all_gather_object(out_list, ts)

        for _ in tqdm(range(n_times)):
            torch.distributed.all_gather_object(out_list, ts)
        end_time = time.time()
        print('Time taken by PyTorch all_gather_object', end_time-start_time,
              'avg', (end_time-start_time) / n_times)

    @unittest.skipIf(check_skip_distributed_slow_test(),
                     'Distributed tests ignored')
    def test_gather_all_performance_sync_shape(self):
        max_shape_size = 10
        shape = [128, 6, DistributedHelper.rank+1] + \
            ([3] * DistributedHelper.rank)

        device = DistributedHelper.make_device()

        def shape_all_gather():
            ts = torch.zeros((max_shape_size,), dtype=torch.int64)
            for i in range(len(shape)):
                ts[i] = shape[i]

            ts = ts.to(device)
            all_tensors_shape = [torch.empty_like(ts)
                                 for _ in range(DistributedHelper.world_size)]
            torch.distributed.all_gather(all_tensors_shape, ts)
            all_tensors_shape = [t.cpu() for t in all_tensors_shape]

            for i, t in enumerate(all_tensors_shape):
                for x in range(len(t)):
                    if t[x] == 0:
                        if x == 0:
                            # Tensor with 0-length shape
                            all_tensors_shape[i] = t[:x+1]
                        else:
                            all_tensors_shape[i] = t[:x]
                        break

        def shape_all_gather_objects():
            out_list = [None for _ in range(DistributedHelper.world_size)]
            torch.distributed.all_gather_object(out_list, shape)

        from tqdm import tqdm
        n_times = 1000
        shape_all_gather()
        start_time = time.time()
        for _ in tqdm(range(n_times)):
            shape_all_gather()
        end_time = time.time()
        print('Time taken by PyTorch all_gather', end_time-start_time,
              'avg', (end_time-start_time) / n_times)

        start_time = time.time()
        shape_all_gather_objects()

        for _ in tqdm(range(n_times)):
            shape_all_gather_objects()
        end_time = time.time()
        print('Time taken by PyTorch all_gather_object', end_time-start_time,
              'avg', (end_time-start_time) / n_times)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_gather_all_same_dim0(self):
        ts = torch.full((10, DistributedHelper.rank+1),
                        DistributedHelper.rank,
                        dtype=torch.long)
        device = DistributedHelper.make_device()

        ts = ts.to(device)

        tensor_list = DistributedHelper.gather_all(ts)
        self.assertEqual(DistributedHelper.world_size, len(tensor_list))

        for t in tensor_list:
            self.assertEqual(device, t.device)

        for rank in range(DistributedHelper.world_size):
            expect = torch.full((10, rank+1),
                                rank,
                                dtype=torch.long).to(device)
            self.assertTrue(torch.equal(tensor_list[rank], expect))

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_gather_all_same_dim1_n(self):
        ts = torch.full((10+DistributedHelper.rank, 5),
                        DistributedHelper.rank,
                        dtype=torch.long)
        device = DistributedHelper.make_device()

        ts = ts.to(device)

        tensor_list = DistributedHelper.gather_all(ts)
        self.assertEqual(DistributedHelper.world_size, len(tensor_list))

        for t in tensor_list:
            self.assertEqual(device, t.device)

        for rank in range(DistributedHelper.world_size):
            expect = torch.full((10+rank, 5),
                                rank,
                                dtype=torch.long).to(device)
            self.assertTrue(torch.equal(tensor_list[rank], expect))

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_gather_all_zero_shaped(self):
        ts = torch.full(tuple(), DistributedHelper.rank, dtype=torch.long)
        device = DistributedHelper.make_device()

        ts = ts.to(device)

        for same_shape in [False, True]:
            print(f'same_shape={same_shape}')
            # with self.subTest(same_shape=same_shape):
            tensor_list = DistributedHelper.gather_all(
                ts,
                same_shape=same_shape)
            self.assertEqual(DistributedHelper.world_size, len(tensor_list))

            for t in tensor_list:
                self.assertEqual(device, t.device)

            for rank in range(DistributedHelper.world_size):
                expect = torch.full(tuple(), rank, dtype=torch.long).to(device)
                self.assertTrue(torch.equal(tensor_list[rank], expect))

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_check_equal_tensors(self):
        if DistributedHelper.world_size == 1 and \
                DistributedHelper.get_device_id() >= 0:
            self.skipTest('When using CUDA, there must be at '
                          'least two processes to run this test')
        torch.manual_seed(1234)
        ts = torch.randn((100,))
        DistributedHelper.check_equal_tensors(ts)

        torch.manual_seed(1234 + DistributedHelper.rank)
        ts = torch.randn((100,))
        with self.assertRaises(Exception):
            DistributedHelper.check_equal_tensors(ts)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_fields(self):
        self.assertEqual(dst.get_rank(), DistributedHelper.rank)
        self.assertEqual(dst.get_world_size(), DistributedHelper.world_size)
        self.assertEqual(True, DistributedHelper.is_distributed)
        self.assertEqual(dst.get_rank() == 0, DistributedHelper.is_main_process)

        if self.use_gpu_in_tests:
            self.assertEqual('nccl', DistributedHelper.backend)
            self.assertTrue(DistributedHelper.forced_cuda_comm)
        else:
            self.assertEqual('gloo', DistributedHelper.backend)
            self.assertFalse(DistributedHelper.forced_cuda_comm)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_set_random_seeds_and_align(self):
        DistributedHelper.set_random_seeds(5678)

        self.assertEqual(297076, np.random.randint(0, 1000000))
        self.assertEqual(643380, torch.randint(0, 1000000, (1,)).item())
        self.assertEqual(683410, random.randint(0, 1000000))

        if DistributedHelper.is_main_process:
            np.random.randint(0, 1000000)
            torch.randint(0, 1000000, (1,))
            random.randint(0, 1000000)

        DistributedHelper.align_seeds()

        ref_values = (
            int(np.random.randint(0, 1000000)),
            int(torch.randint(0, 1000000, (1,))),
            int(random.randint(0, 1000000))
        )

        DistributedHelper.check_equal_objects(ref_values)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_rolling_seed_aligner(self):
        RNGManager.set_random_seeds(4321)

        with RollingSeedContext():
            RNGManager.set_random_seeds(1234 + DistributedHelper.rank)
            random.randint(0, 2 ** 64 - 1)

        final_value = random.randint(0, 2 ** 64 - 1)
        self.assertEqual(14732185405572191734, final_value)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_broadcast_seed_aligner(self):
        RNGManager.set_random_seeds(4321)

        with BroadcastSeedContext():
            RNGManager.set_random_seeds(1234 + DistributedHelper.rank)
            random.randint(0, 2 ** 64 - 1)

        final_value = random.randint(0, 2 ** 64 - 1)
        self.assertEqual(15306775005444441373, final_value)

    @unittest.skipIf(check_skip_distributed_test(),
                     'Distributed tests ignored')
    def test_main_process_first(self):
        tmpdirname = ''
        try:
            my_rank = DistributedHelper.rank
            if DistributedHelper.is_main_process:
                tmpdirname = tempfile.mkdtemp()

            tmpdirname = DistributedHelper.broadcast_object(tmpdirname)

            with DistributedHelper.main_process_first():

                for _ in range(2):
                    time.sleep(0.1 + my_rank * 0.05)
                    files = list(os.listdir(tmpdirname))
                    if DistributedHelper.is_main_process:
                        self.assertEqual(0, len(files))
                    else:
                        self.assertIn('rank0', files)
                        self.assertNotIn(f'rank{my_rank}', files)

                with open(os.path.join(tmpdirname, f'rank{my_rank}'), 'w') \
                        as f:
                    f.write('ok')

                for _ in range(2):
                    time.sleep(0.1 + my_rank * 0.05)
                    files = list(os.listdir(tmpdirname))
                    if DistributedHelper.is_main_process:
                        self.assertEqual(1, len(files))
                        self.assertIn('rank0', files)
                    else:
                        self.assertIn('rank0', files)
                        self.assertIn(f'rank{my_rank}', files)

            DistributedHelper.barrier()
            files = set(os.listdir(tmpdirname))
            expect = set([f'rank{rnk}'
                          for rnk in range(DistributedHelper.world_size)])
            self.assertSetEqual(expect, files)
            DistributedHelper.barrier()
        finally:
            if tmpdirname is not None and DistributedHelper.is_main_process:
                shutil.rmtree(tmpdirname)


if __name__ == "__main__":
    with suppress_dst_tests_output():
        verbosity = 1
        if DistributedHelper.rank > 0:
            verbosity = 0
        unittest.main(verbosity=verbosity)
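
Note on the uncovered test bodies: every test above is guarded by check_skip_distributed_test or check_skip_distributed_slow_test, so the bodies only run when the suite is launched as a multi-process torch.distributed job. For reference only, the following is a minimal, hypothetical sketch of the kind of process-group setup such a launch provides; the repository's actual logic lives in common_dst_tests_setup and the check_skip_* helpers of tests.distributed.distributed_test_utils and may differ.

# Hypothetical sketch, not the repository's actual helper: join the default
# process group that DistributedHelper-based tests rely on. Under torchrun,
# each worker receives RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and
# LOCAL_RANK in its environment, which init_method='env://' reads.
import os

import torch
import torch.distributed as dst


def init_test_process_group() -> bool:
    """Initialize torch.distributed; return True when CUDA/NCCL is in use."""
    use_cuda = torch.cuda.is_available()
    dst.init_process_group(
        backend='nccl' if use_cuda else 'gloo',
        init_method='env://',
    )
    if use_cuda:
        # Pin this worker to its own GPU before any CUDA tensor is created.
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    return use_cuda

With a setup along these lines, launching the file with two or more worker processes would exercise the statements that appear uncovered in this report.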