ContinualAI / avalanche · build 4993189103 (pending completion)

Pull Request #1370 (github): Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered (31.39%).
80 existing lines in 5 files are now uncovered.
15585 of 21651 relevant lines covered (71.98%).
2.88 hits per line.

Source file: /avalanche/benchmarks/utils/data_loader.py (81.35% covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Avalanche supports data loading using PyTorch's dataloaders.
    This module provides custom dataloaders for continual learning, such as
    support for balanced dataloading between different tasks or balancing
    between the current data and the replay memory.
"""
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch.utils.data import RandomSampler, DistributedSampler, Dataset
from torch.utils.data.dataloader import DataLoader

from avalanche.benchmarks.utils.collate_functions import (
    classification_collate_mbatches_fn,
)
from avalanche.benchmarks.utils.collate_functions import (
    detection_collate_fn as _detection_collate_fn,
)
from avalanche.benchmarks.utils.collate_functions import (
    detection_collate_mbatches_fn as _detection_collate_mbatches_fn,
)
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.distributed.distributed_helper import DistributedHelper

_default_collate_mbatches_fn = classification_collate_mbatches_fn

detection_collate_fn = _detection_collate_fn

detection_collate_mbatches_fn = _detection_collate_mbatches_fn


def return_identity(x):
    """
    The identity function. Can be wrapped in 'partial'
    to act as a getter function.
    Used to avoid lambda functions, which cannot be pickled.
    """
    return x


def collate_from_data_or_kwargs(data, kwargs):
    if "collate_fn" in kwargs:
        return
    elif hasattr(data, "collate_fn"):
        kwargs["collate_fn"] = data.collate_fn

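# Illustrative sketch (not part of the original module): how
# collate_from_data_or_kwargs resolves the collate function. An explicitly
# provided "collate_fn" is left untouched; otherwise the dataset's own
# `collate_fn` attribute is copied into the DataLoader keyword arguments.
# `_ExampleDataset` is a hypothetical stand-in for an AvalancheDataset.
def _example_collate_resolution():
    class _ExampleDataset:
        def __init__(self):
            self.collate_fn = return_identity

    kwargs = {}
    collate_from_data_or_kwargs(_ExampleDataset(), kwargs)
    assert kwargs["collate_fn"] is return_identity

    kwargs = {"collate_fn": list}
    collate_from_data_or_kwargs(_ExampleDataset(), kwargs)
    assert kwargs["collate_fn"] is list  # an explicit collate_fn wins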

class TaskBalancedDataLoader:
    """Task-balanced data loader for Avalanche's datasets."""

    def __init__(
        self,
        data: AvalancheDataset,
        oversample_small_tasks: bool = False,
        **kwargs
    ):
        """Task-balanced data loader for Avalanche's datasets.

        The iterator returns a mini-batch balanced across each task, which
        makes it useful when training in multi-task scenarios where the data
        is highly unbalanced.

        If `oversample_small_tasks == True`, smaller tasks are
        oversampled to match the largest task. Otherwise, once the data for a
        specific task is exhausted, that task will not be present in the
        subsequent mini-batches.

        :param data: an instance of `AvalancheDataset`.
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param kwargs: data loader arguments used to instantiate the loader for
            each task separately. See PyTorch :class:`DataLoader`.
        """
        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )

        self.data = data
        self.dataloaders: Dict[int, DataLoader] = dict()
        self.oversample_small_tasks = oversample_small_tasks

        # split data by task.
        task_datasets = []
        task_labels_field = getattr(self.data, 'targets_task_labels')
        assert isinstance(task_labels_field, DataAttribute)
        for task_label in task_labels_field.uniques:
            tidxs = task_labels_field.val_to_idx[task_label]
            tdata = self.data.subset(tidxs)
            task_datasets.append(tdata)

        # the iteration logic is implemented by GroupBalancedDataLoader.
        # we use kwargs to pass the arguments to avoid passing the same
        # arguments multiple times.
        if "data" in kwargs:
            del kwargs["data"]
        # needed if they are passed as positional arguments
        kwargs["oversample_small_groups"] = oversample_small_tasks
        self._dl = GroupBalancedDataLoader(datasets=task_datasets, **kwargs)

    def __iter__(self):
        for el in self._dl.__iter__():
            yield el

    def __len__(self):
        return self._dl.__len__()

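# Illustrative sketch (not part of the original module): a minimal way to build
# a TaskBalancedDataLoader on one experience of a benchmark. The SplitMNIST
# benchmark and the structure of the collated batches follow Avalanche's
# classification conventions and are assumptions, not taken from this file.
def _example_task_balanced_loading():
    from avalanche.benchmarks.classic import SplitMNIST

    benchmark = SplitMNIST(n_experiences=5)
    experience = benchmark.train_stream[0]
    loader = TaskBalancedDataLoader(
        experience.dataset,
        oversample_small_tasks=True,
        batch_size=32,
        shuffle=True,
    )
    for batch in loader:
        # for classification data, a batch typically unpacks as
        # (inputs, class labels, task labels)
        break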

class GroupBalancedDataLoader:
    """Data loader that balances data from multiple datasets."""

    def __init__(
        self,
        datasets: Sequence[AvalancheDataset],
        oversample_small_groups: bool = False,
        batch_size: int = 32,
        distributed_sampling: bool = True,
        **kwargs
    ):
        """Data loader that balances data from multiple datasets.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance data
        among classes, experiences, tasks, and so on.

        If `oversample_small_groups == True`, smaller groups are oversampled to
        match the largest group. Otherwise, once data from a group is
        completely iterated, the group will be skipped.

        :param datasets: a sequence of `AvalancheDataset` instances.
        :param oversample_small_groups: whether smaller groups should be
            oversampled to match the largest one.
        :param batch_size: the size of the batch. It must be greater than or
            equal to the number of groups.
        :param distributed_sampling: if True, a :class:`DistributedSampler` is
            used for each group when running in a distributed environment.
        :param kwargs: data loader arguments used to instantiate the loader for
            each group separately. See PyTorch :class:`DataLoader`.
        """
        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )

        self.datasets = datasets
        self.batch_sizes = []
        self.oversample_small_groups = oversample_small_groups
        self.distributed_sampling = distributed_sampling
        self.loader_kwargs = kwargs
        if "collate_fn" in kwargs:
            self.collate_fn = kwargs["collate_fn"]
        else:
            self.collate_fn = self.datasets[0].collate_fn

        # collate is done after we have all batches
        # so we set an identity collate for the internal dataloaders
        self.loader_kwargs["collate_fn"] = return_identity

        # check if batch_size is larger than or equal to the number of datasets
        assert batch_size >= len(datasets)

        # divide the batch between all datasets in the group
        ds_batch_size = batch_size // len(datasets)
        remaining = batch_size % len(datasets)

        for _ in self.datasets:
            bs = ds_batch_size
            if remaining > 0:
                bs += 1
                remaining -= 1
            self.batch_sizes.append(bs)

        loaders_for_len_estimation = [
            _make_data_loader(
                dataset,
                distributed_sampling,
                kwargs,
                mb_size,
                force_no_workers=True,
            )[0]
            for dataset, mb_size in zip(self.datasets, self.batch_sizes)
        ]

        self.max_len = max([len(d) for d in loaders_for_len_estimation])

    def __iter__(self):
        dataloaders = []
        samplers = []
        for dataset, mb_size in zip(self.datasets, self.batch_sizes):
            data_l, data_l_sampler = _make_data_loader(
                dataset,
                self.distributed_sampling,
                self.loader_kwargs,
                mb_size,
            )

            dataloaders.append(data_l)
            samplers.append(data_l_sampler)

        iter_dataloaders = []
        for dl in dataloaders:
            iter_dataloaders.append(iter(dl))

        max_num_mbatches = max([len(d) for d in dataloaders])
        for it in range(max_num_mbatches):
            mb_curr = []
            removed_dataloaders_idxs = []
            # removals are deferred (see removed_dataloaders_idxs below):
            # deleting elements while looping over the lists would break
            # the iteration.
            for tid, (t_loader, t_loader_sampler) in enumerate(
                zip(iter_dataloaders, samplers)
            ):
                try:
                    batch = next(t_loader)
                except StopIteration:
                    # StopIteration is thrown if dataset ends.
                    if self.oversample_small_groups:
                        # reinitialize data loader
                        if isinstance(t_loader_sampler, DistributedSampler):
                            # Manage shuffling in DistributedSampler
                            t_loader_sampler.set_epoch(
                                t_loader_sampler.epoch + 1
                            )

                        iter_dataloaders[tid] = iter(dataloaders[tid])
                        batch = next(iter_dataloaders[tid])
                    else:
                        # We iterated over all the data from this group
                        # and we don't need the iterator anymore.
                        iter_dataloaders[tid] = None
                        samplers[tid] = None
                        removed_dataloaders_idxs.append(tid)
                        continue
                mb_curr.extend(batch)
            yield self.collate_fn(mb_curr)

            # remove exhausted data loaders
            for tid in reversed(removed_dataloaders_idxs):
                del iter_dataloaders[tid]
                del samplers[tid]

    def __len__(self):
        return self.max_len

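# Illustrative sketch (not part of the original module): balancing mini-batches
# across two groups, here the datasets of the first two experiences of a
# SplitMNIST benchmark (a hypothetical setup, not taken from this file). With
# batch_size=32 and two groups, each group contributes 16 samples per batch.
def _example_group_balanced_loading():
    from avalanche.benchmarks.classic import SplitMNIST

    benchmark = SplitMNIST(n_experiences=5)
    groups = [
        benchmark.train_stream[0].dataset,
        benchmark.train_stream[1].dataset,
    ]
    loader = GroupBalancedDataLoader(
        groups,
        oversample_small_groups=True,
        batch_size=32,
        shuffle=True,
    )
    for batch in loader:
        # each mini-batch mixes roughly 16 samples from each group
        break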

class GroupBalancedInfiniteDataLoader:
    """Data loader that balances data from multiple datasets emitting an
    infinite stream."""

    def __init__(
        self,
        datasets: Sequence[AvalancheDataset],
        collate_mbatches=_default_collate_mbatches_fn,
        distributed_sampling: bool = True,
        **kwargs
    ):
        """Data loader that balances data from multiple datasets emitting an
        infinite stream.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance data
        among classes, experiences, tasks, and so on.

        :param datasets: a sequence of `AvalancheDataset` instances.
        :param collate_mbatches: function that given a sequence of mini-batches
            (one for each task) combines them into a single mini-batch. Used to
            combine the mini-batches obtained separately from each task.
        :param distributed_sampling: if True, the random sampling of each group
            is seeded differently on each process when running in a distributed
            environment.
        :param kwargs: data loader arguments used to instantiate the loader for
            each group separately. See PyTorch :class:`DataLoader`.
        """
        self.datasets = datasets
        self.dataloaders = []
        self.collate_mbatches = collate_mbatches

        for data in self.datasets:
            if DistributedHelper.is_distributed and distributed_sampling:
                seed = torch.randint(
                    0,
                    2 ** 32 - 1 - DistributedHelper.world_size,
                    (1,),
                    dtype=torch.int64,
                )
                seed += DistributedHelper.rank
                generator = torch.Generator()
                generator.manual_seed(int(seed))
            else:
                generator = None  # Default
            infinite_sampler = RandomSampler(
                data,
                replacement=True,
                num_samples=10 ** 10,
                generator=generator,
            )
            collate_from_data_or_kwargs(data, kwargs)
            dl = DataLoader(data, sampler=infinite_sampler, **kwargs)
            self.dataloaders.append(dl)
        self.max_len = 10 ** 10

    def __iter__(self):
        iter_dataloaders = []
        for dl in self.dataloaders:
            iter_dataloaders.append(iter(dl))

        while True:
            mb_curr = []
            for tid, t_loader in enumerate(iter_dataloaders):
                batch = next(t_loader)
                mb_curr.append(batch)
            yield self.collate_mbatches(mb_curr)

    def __len__(self):
        return self.max_len

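# Illustrative sketch (not part of the original module): drawing a fixed number
# of mini-batches from the infinite stream with itertools.islice. The benchmark
# and the default collate behaviour follow Avalanche's classification
# conventions and are assumptions, not taken from this file.
def _example_infinite_group_balanced_loading():
    import itertools

    from avalanche.benchmarks.classic import SplitMNIST

    benchmark = SplitMNIST(n_experiences=5)
    groups = [exp.dataset for exp in benchmark.train_stream]
    loader = GroupBalancedInfiniteDataLoader(groups, batch_size=8)
    # the loader never raises StopIteration, so bound the iteration explicitly
    for mb in itertools.islice(loader, 100):
        pass  # 100 mini-batches, each containing 8 samples per group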

class ReplayDataLoader:
    """Custom data loader for rehearsal/replay strategies."""

    def __init__(
        self,
        data: AvalancheDataset,
        memory: Optional[AvalancheDataset] = None,
        oversample_small_tasks: bool = False,
        batch_size: int = 32,
        batch_size_mem: int = 32,
        task_balanced_dataloader: bool = False,
        distributed_sampling: bool = True,
        **kwargs
    ):
        """Custom data loader for rehearsal strategies.

        This dataloader iterates in parallel two datasets, the current `data`
        and the rehearsal `memory`, which are used to create mini-batches by
        concatenating their data together. Mini-batches from both of them are
        balanced using the task label (i.e. each mini-batch contains a balanced
        number of examples from all the tasks in the `data` and `memory`).

        The length of the loader is determined only by the current
        task data and is the same as it would be when creating a
        data loader for this dataset.

        If `oversample_small_tasks == True`, smaller tasks are oversampled to
        match the largest task.

        :param data: the current experience's `AvalancheDataset`.
        :param memory: the rehearsal `AvalancheDataset` (the replay buffer).
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param batch_size: the size of the data batch. It must be greater
            than or equal to the number of tasks.
        :param batch_size_mem: the size of the memory batch. If
            `task_balanced_dataloader` is set to True, it must be greater than
            or equal to the number of tasks.
        :param task_balanced_dataloader: if true, buffer data loaders will be
            task-balanced, otherwise it creates a single data loader for the
            buffer samples.
        :param distributed_sampling: if True, a :class:`DistributedSampler` is
            used for each loader when running in a distributed environment.
        :param kwargs: data loader arguments used to instantiate the loader for
            each task separately. See PyTorch :class:`DataLoader`.
        """
        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )

        self.data = data
        self.memory = memory
        self.oversample_small_tasks = oversample_small_tasks
        self.task_balanced_dataloader = task_balanced_dataloader
        self.data_batch_sizes: Union[int, Dict[int, int]] = dict()
        self.memory_batch_sizes: Union[int, Dict[int, int]] = dict()
        self.distributed_sampling = distributed_sampling
        self.loader_kwargs = kwargs

        if "collate_fn" in kwargs:
            self.collate_fn = kwargs["collate_fn"]
        else:
            self.collate_fn = self.data.collate_fn

        # collate is done after we have all batches
        # so we set an identity collate for the internal dataloaders
        # (return_identity is used instead of a lambda so that the loader
        # arguments remain picklable, e.g. when num_workers > 0)
        self.loader_kwargs["collate_fn"] = return_identity

        if task_balanced_dataloader:
            memory_task_labels = getattr(self.memory, 'targets_task_labels')
            assert isinstance(memory_task_labels, DataAttribute)
            num_keys = len(memory_task_labels.uniques)
            assert batch_size_mem >= num_keys, (
                "Batch size must be greater than or equal "
                "to the number of tasks in the memory "
                "and current data."
            )

        self.data_batch_sizes, _ = self._get_batch_sizes(
            data, batch_size, 0, False
        )

        # Compute the batch sizes for the memory items
        if task_balanced_dataloader:
            memory_task_labels = getattr(self.memory, 'targets_task_labels')
            assert isinstance(memory_task_labels, DataAttribute)
            num_keys = len(memory_task_labels.uniques)
            single_group_batch_size = batch_size_mem // num_keys
            remaining_example = batch_size_mem % num_keys
        else:
            single_group_batch_size = batch_size_mem
            remaining_example = 0

        self.memory_batch_sizes, _ = self._get_batch_sizes(
            memory,
            single_group_batch_size,
            remaining_example,
            task_balanced_dataloader,
        )

        loaders_for_len_estimation = []

        if isinstance(self.data_batch_sizes, int):
            loaders_for_len_estimation.append(
                _make_data_loader(
                    data,
                    distributed_sampling,
                    kwargs,
                    self.data_batch_sizes,
                    force_no_workers=True,
                )[0]
            )
        else:
            # Task balanced
            data_task_set: Mapping[int, AvalancheDataset] = \
                getattr(data, 'task_set')
            for task_id in data_task_set:
                dataset = data_task_set[task_id]
                mb_sz = self.data_batch_sizes[task_id]

                loaders_for_len_estimation.append(
                    _make_data_loader(
                        dataset,
                        distributed_sampling,
                        kwargs,
                        mb_sz,
                        force_no_workers=True,
                    )[0]
                )

        self.max_len = max([len(d) for d in loaders_for_len_estimation])

    def __iter__(self):
        loader_data, sampler_data = self._create_loaders_and_samplers(
            self.data, self.data_batch_sizes
        )

        loader_memory, sampler_memory = self._create_loaders_and_samplers(
            self.memory, self.memory_batch_sizes
        )

        iter_data_dataloaders = {}
        iter_buffer_dataloaders = {}

        for t in loader_data.keys():
            iter_data_dataloaders[t] = iter(loader_data[t])
        for t in loader_memory.keys():
            iter_buffer_dataloaders[t] = iter(loader_memory[t])

        max_len = max([len(d) for d in loader_data.values()])

        try:
            for it in range(max_len):
                mb_curr: List[Any] = []
                ReplayDataLoader._get_mini_batch_from_data_dict(
                    iter_data_dataloaders,
                    sampler_data,
                    loader_data,
                    self.oversample_small_tasks,
                    mb_curr,
                )

                ReplayDataLoader._get_mini_batch_from_data_dict(
                    iter_buffer_dataloaders,
                    sampler_memory,
                    loader_memory,
                    self.oversample_small_tasks,
                    mb_curr,
                )

                yield self.collate_fn(mb_curr)
        except StopIteration:
            return

    def __len__(self):
        return self.max_len

    @staticmethod
    def _get_mini_batch_from_data_dict(
        iter_dataloaders,
        iter_samplers,
        loaders_dict,
        oversample_small_tasks,
        mb_curr,
    ):
        # list() is necessary because we may remove keys from the
        # dictionary, which would otherwise break the iteration.
        for t in list(iter_dataloaders.keys()):
            t_loader = iter_dataloaders[t]
            t_sampler = iter_samplers[t]
            try:
                tbatch = next(t_loader)
            except StopIteration:
                # StopIteration is thrown if dataset ends.
                if oversample_small_tasks:
                    # reinitialize data loader
                    if isinstance(t_sampler, DistributedSampler):
                        # Manage shuffling in DistributedSampler
                        t_sampler.set_epoch(t_sampler.epoch + 1)

                    iter_dataloaders[t] = iter(loaders_dict[t])
                    tbatch = next(iter_dataloaders[t])
                else:
                    del iter_dataloaders[t]
                    del iter_samplers[t]
                    continue
            mb_curr.extend(tbatch)

    def _create_loaders_and_samplers(self, data, batch_sizes):
        loaders = dict()
        samplers = dict()

        if isinstance(batch_sizes, int):
            loader, sampler = _make_data_loader(
                data,
                self.distributed_sampling,
                self.loader_kwargs,
                batch_sizes,
            )
            loaders[0] = loader
            samplers[0] = sampler
        else:
            for task_id in data.task_set:
                dataset = data.task_set[task_id]
                mb_sz = batch_sizes[task_id]

                loader, sampler = _make_data_loader(
                    dataset,
                    self.distributed_sampling,
                    self.loader_kwargs,
                    mb_sz,
                )

                loaders[task_id] = loader
                samplers[task_id] = sampler
        return loaders, samplers

    @staticmethod
    def _get_batch_sizes(
        data_dict,
        single_exp_batch_size,
        remaining_example,
        task_balanced_dataloader,
    ):
        batch_sizes = dict()
        if task_balanced_dataloader:
            for task_id in data_dict.task_set:
                current_batch_size = single_exp_batch_size
                if remaining_example > 0:
                    current_batch_size += 1
                    remaining_example -= 1
                batch_sizes[task_id] = current_batch_size
        else:
            # Current data is loaded without task balancing
            batch_sizes = single_exp_batch_size
        return batch_sizes, remaining_example

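# Illustrative sketch (not part of the original module): combining the current
# experience's data with a replay buffer. Using another experience's dataset as
# a stand-in for the buffer is an assumption for illustration; any
# AvalancheDataset can serve as `memory`.
def _example_replay_loading():
    from avalanche.benchmarks.classic import SplitMNIST

    benchmark = SplitMNIST(n_experiences=5)
    current_data = benchmark.train_stream[1].dataset
    memory = benchmark.train_stream[0].dataset  # stands in for a replay buffer
    loader = ReplayDataLoader(
        current_data,
        memory,
        oversample_small_tasks=True,
        batch_size=32,      # samples per mini-batch drawn from `current_data`
        batch_size_mem=32,  # samples per mini-batch drawn from `memory`
        shuffle=True,
    )
    for batch in loader:
        # each mini-batch mixes 32 current samples with 32 replay samples
        break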

def _make_data_loader(
    dataset: Dataset,
    distributed_sampling: bool,
    data_loader_args: Dict[str, Any],
    batch_size: int,
    force_no_workers: bool = False,
):
    data_loader_args = data_loader_args.copy()

    collate_from_data_or_kwargs(dataset, data_loader_args)

    if force_no_workers:
        data_loader_args['num_workers'] = 0
        if 'persistent_workers' in data_loader_args:
            data_loader_args['persistent_workers'] = False
        if 'prefetch_factor' in data_loader_args:
            data_loader_args['prefetch_factor'] = 2

    if DistributedHelper.is_distributed and distributed_sampling:
        # Note: shuffle only goes in the sampler, while
        # drop_last must be passed to both the sampler
        # and the DataLoader
        drop_last = data_loader_args.pop("drop_last", False)
        sampler = DistributedSampler(
            dataset,
            shuffle=data_loader_args.pop("shuffle", True),
            drop_last=drop_last,
        )
        data_loader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
            drop_last=drop_last,
            **data_loader_args
        )
    else:
        sampler = None
        data_loader = DataLoader(
            dataset, batch_size=batch_size, **data_loader_args
        )

    return data_loader, sampler

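# Illustrative sketch (not part of the original module): how _make_data_loader
# is typically invoked by the loaders above. Outside a distributed run
# (DistributedHelper.is_distributed is False) the returned sampler is None and
# a plain DataLoader is built; in a distributed run a DistributedSampler is
# returned so that callers can advance its epoch when re-iterating.
def _example_make_data_loader(dataset):
    loader_args = {"shuffle": True, "num_workers": 0, "drop_last": False}
    data_loader, sampler = _make_data_loader(
        dataset,
        distributed_sampling=True,
        data_loader_args=loader_args,
        batch_size=16,
    )
    if isinstance(sampler, DistributedSampler):
        # shuffling changes across epochs only if the epoch is advanced
        sampler.set_epoch(sampler.epoch + 1)
    return data_loader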

__all__ = [
    "detection_collate_fn",
    "detection_collate_mbatches_fn",
    "collate_from_data_or_kwargs",
    "TaskBalancedDataLoader",
    "GroupBalancedDataLoader",
    "ReplayDataLoader",
    "GroupBalancedInfiniteDataLoader",
]