ContinualAI / avalanche · build 5268393053 (pending completion)
Pull Request #1397: Specialize benchmark creation helpers
github · web-flow · Merge 60d244754 into e91562200

417 of 538 new or added lines in 30 files covered (77.51%).
43 existing lines in 5 files are now uncovered.
16586 of 22630 relevant lines covered (73.29%), with 2.93 hits per line.
Source file: /avalanche/benchmarks/utils/detection_dataset.py (41.82% covered)

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini, Antonio Carta                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
This module contains the implementation of the ``DetectionDataset``,
which is the dataset used for supervised continual learning benchmarks.
DetectionDatasets are ``AvalancheDatasets`` that manage targets and task
labels automatically. Concatenation and subsampling operations are optimized
to be used frequently, as is common in replay strategies.
"""
from functools import partial
from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    Callable,
    Dict,
    Tuple,
    Mapping,
    overload,
)

import torch
from torch import Tensor
from torch.utils.data.dataset import Subset, ConcatDataset

from avalanche.benchmarks.utils.utils import (
    TaskSet,
    _init_task_labels,
    _init_transform_groups,
    _split_user_def_targets,
    _split_user_def_task_label,
    _traverse_supported_dataset,
)

from .collate_functions import detection_collate_fn
from .data import AvalancheDataset
from .data_attribute import DataAttribute
from .dataset_definitions import (
    IDataset,
    IDatasetWithTargets,
)
from .dataset_utils import (
    SubSequence,
)
from .flat_data import ConstantSequence
from .transform_groups import (
    TransformGroupDef,
    DefaultTransformGroups,
    TransformGroups,
    XTransform,
    YTransform,
)

T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TTargetType = Dict[str, Tensor]


# Image (tensor), target dict, task label
DetectionExampleT = Tuple[Tensor, TTargetType, int]
TDetectionDataset = TypeVar("TDetectionDataset", bound="DetectionDataset")


class DetectionDataset(AvalancheDataset[T_co]):

    def __init__(
            self,
            datasets: Sequence[IDataset[T_co]],
            *,
            indices: Optional[List[int]] = None,
            data_attributes: Optional[List[DataAttribute]] = None,
            transform_groups: Optional[TransformGroups] = None,
            frozen_transform_groups: Optional[TransformGroups] = None,
            collate_fn: Optional[Callable[[List], Any]] = None):
        super().__init__(
            datasets=datasets,
            indices=indices,
            data_attributes=data_attributes,
            transform_groups=transform_groups,
            frozen_transform_groups=frozen_transform_groups,
            collate_fn=collate_fn
        )

        assert hasattr(self, 'targets'), \
            'The supervised version of the DetectionDataset requires ' + \
            'the targets field'
        assert hasattr(self, 'targets_task_labels'), \
            'The supervised version of the DetectionDataset requires ' + \
            'the targets_task_labels field'

    @property
    def targets(self) -> DataAttribute[TTargetType]:
        return self._data_attributes['targets']

    @property
    def targets_task_labels(self) -> DataAttribute[int]:
        return self._data_attributes['targets_task_labels']

    @property
    def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
        """A dictionary mapping task ids to their sample indices."""
        # Assumes that the targets_task_labels attribute exists
        t_labels: DataAttribute[int] = self.targets_task_labels
        return t_labels.val_to_idx

    @property
    def task_set(self: TDetectionDataset) -> TaskSet[TDetectionDataset]:
        """Returns the dataset's ``TaskSet``, which is a mapping
        <task-id, task-dataset>."""
        return TaskSet(self)

    def subset(self, indices):
        """Returns a subset that keeps the current transform group."""
        data = super().subset(indices)
        return data.with_transforms(self._transform_groups.current_group)

    def concat(self, other):
        """Concatenates this dataset with ``other``, keeping the current
        transform group."""
        data = super().concat(other)
        return data.with_transforms(self._transform_groups.current_group)

    def __hash__(self):
        return id(self)

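
# --- Editorial usage sketch (not part of the original source file) ---
# A minimal illustration of the task-aware accessors defined above. The tiny
# in-memory dataset is hypothetical; any dataset exposing a `targets` list of
# detection-style dicts works the same way. `make_detection_dataset` is
# defined later in this module and is resolved when this function is called.
def _example_task_accessors():
    images = [torch.rand(3, 32, 32) for _ in range(4)]
    boxes = torch.tensor([[0.0, 0.0, 4.0, 4.0]])

    class _TinyDetectionDataset:
        # Minimal torchvision-style detection dataset (hypothetical).
        def __init__(self):
            self.targets = [
                {"boxes": boxes, "labels": torch.tensor([1])}
                for _ in range(len(images))
            ]

        def __len__(self):
            return len(images)

        def __getitem__(self, i):
            return images[i], self.targets[i]

    data = make_detection_dataset(
        _TinyDetectionDataset(), task_labels=[0, 0, 1, 1]
    )
    # task_pattern_indices groups sample indices by task label...
    assert set(data.task_pattern_indices.keys()) == {0, 1}
    # ...while task_set exposes one sub-dataset per task.
    task0 = data.task_set[0]
    # subset()/concat() keep the currently selected transform group.
    return task0.subset([0]).concat(data.task_set[1])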

SupportedDetectionDataset = Union[
    IDatasetWithTargets,
    Subset,
    ConcatDataset,
    DetectionDataset
]


def make_detection_dataset(
    dataset: SupportedDetectionDataset,
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str, TransformGroupDef]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset:
    """Creates an Avalanche detection dataset (an ``AvalancheDataset``
    specialized for object detection).

    Supervised continual learning benchmarks in Avalanche return instances of
    this dataset, but it can also be used in a completely standalone manner.

    This dataset applies input/target transformations and supports slicing
    and advanced indexing. It also contains useful fields such as `targets`,
    which holds the target dictionaries, and `targets_task_labels`, which
    holds the pattern task labels. The `task_set` field can be used to obtain
    the subset of patterns labeled with a given task label.

    This dataset can also be used to apply several advanced operations
    involving transformations. For instance, it allows the user to add and
    replace transformations, freeze them so that they can't be changed, etc.

    This dataset also allows the user to keep distinct transformation groups.
    Simply put, a transformation group is a pair of transform+target_transform
    (exactly as in torchvision datasets). This dataset natively supports two
    transformation groups: the first, 'train', contains transformations
    applied to training patterns, which usually involve some kind of data
    augmentation; the second, 'eval', contains transformations applied to
    test patterns. Having both groups is useful when, for instance, one needs
    to test on the training data (as this process usually involves removing
    data augmentation operations). Switching between transformations can be
    easily achieved by using the :func:`train` and :func:`eval` methods.

    Moreover, arbitrary transformation groups can be added and used. For more
    info see the constructor and the :func:`with_transforms` method.

    This dataset will try to inherit the task labels from the input
    dataset. If none are available and none are given via the `task_labels`
    parameter, each pattern will be assigned a default task label 0.

    :param dataset: The dataset to decorate. Beware that
        AvalancheDataset will not overwrite transformations already
        applied by this dataset.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label of each instance. Must be a sequence
        of ints, one for each instance in the dataset. Alternatively, it can
        be a single int value, in which case that value will be used as the
        task label for all the instances. Defaults to None, which means that
        the dataset will try to obtain the task labels from the original
        dataset. If no task labels could be found, a default task label
        0 will be applied to all instances.
    :param targets: The dictionary of detection targets of each pattern.
        Defaults to None, which means that the targets will be retrieved from
        the dataset (if possible).
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a
        `collate_fn` field exists in the dataset. If no such field exists,
        the default collate function for detection will be used.
    """

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )
    targets_data: Optional[DataAttribute[TTargetType]] = \
        _init_targets(dataset, targets)
    task_labels_data: Optional[DataAttribute[int]] = \
        _init_task_labels(dataset, task_labels)

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    if collate_fn is None:
        collate_fn = getattr(dataset, 'collate_fn', detection_collate_fn)

    data: DetectionDataset = DetectionDataset(
        [dataset],
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        collate_fn=collate_fn,
    )

    if initial_transform_group is not None:
        return data.with_transforms(initial_transform_group)
    else:
        return data

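
# --- Editorial usage sketch (not part of the original source file) ---
# A minimal sketch of the transform-group mechanics described in the
# docstring above, assuming a torchvision-style detection dataset is passed
# in by the caller. The torchvision transforms used here are illustrative.
def _example_transform_groups(detection_dataset):
    from torchvision import transforms

    train_t = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
    )
    eval_t = transforms.ToTensor()

    # One (transform, target_transform) pair per group.
    data = make_detection_dataset(
        detection_dataset,
        transform_groups={
            "train": (train_t, None),
            "eval": (eval_t, None),
        },
        initial_transform_group="train",
        task_labels=0,
    )
    eval_view = data.eval()         # same samples, 'eval' transforms applied
    train_view = eval_view.train()  # switch back to the 'train' group
    return train_view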

def _init_targets(dataset, targets, check_shape=True) -> \
        Optional[DataAttribute[TTargetType]]:
    if targets is not None:
        # User-defined targets always take precedence
        if len(targets) != len(dataset) and check_shape:
            raise ValueError(
                "Invalid amount of target labels. It must be equal to the "
                "number of patterns in the dataset. Got {}, expected "
                "{}!".format(len(targets), len(dataset))
            )
        return DataAttribute(targets, "targets")

    targets = _traverse_supported_dataset(dataset, _select_targets)

    if targets is None:
        return None

    return DataAttribute(targets, "targets")


def _detection_class_mapping_transform(class_mapping, example_target_dict):
    example_target_dict = dict(example_target_dict)

    # example_target_dict["labels"] is a tensor containing one label
    # for each bounding box in the image. Each label must be remapped.
    example_target_labels = example_target_dict["labels"]
    example_mapped_labels = [
        class_mapping[int(el)] for el in example_target_labels
    ]

    if isinstance(example_target_labels, Tensor):
        example_mapped_labels = torch.as_tensor(example_mapped_labels)

    example_target_dict["labels"] = example_mapped_labels

    return example_target_dict

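
# --- Editorial sketch (not part of the original source file) ---
# Shows what the helper above does to a single target dict:
# class_mapping[old_label] gives the new label of each bounding box.
def _example_class_mapping():
    target = {
        "boxes": torch.tensor(
            [[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 8.0, 8.0]]
        ),
        "labels": torch.tensor([1, 2]),
    }
    class_mapping = [0, 3, 4]  # 0 -> 0, 1 -> 3, 2 -> 4
    remapped = _detection_class_mapping_transform(class_mapping, target)
    assert remapped["labels"].tolist() == [3, 4]
    return remapped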

def detection_subset(
    dataset: SupportedDetectionDataset,
    indices: Optional[Sequence[int]] = None,
    *,
    class_mapping: Optional[Sequence[int]] = None,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str,
                                       Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int, Sequence[int]]] = None,
    targets: Optional[Sequence[TTargetType]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset:
    """Creates an ``AvalancheSubset`` instance.

    For simple subset operations you should use the method
    `dataset.subset(indices)`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`.
    This Dataset also supports transformations, slicing, advanced indexing,
    the targets field, class mapping and all the other goodies listed in
    :class:`AvalancheDataset`.

    :param dataset: The whole dataset.
    :param indices: Indices in the whole set selected for subset. Can
        be None, which means that the whole dataset will be returned.
    :param class_mapping: A list that, for each possible class label value,
        contains its corresponding remapped value. Can be None.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that the current group of
        the input dataset will be used (if an AvalancheDataset). If the
        input dataset is not an AvalancheDataset, then 'train' will be
        used.
    :param task_labels: The task label for each instance. Must be a sequence
        of ints, one for each instance in the dataset. This can either be a
        list of task labels for the original dataset or the list of task
        labels for the instances of the subset (an automatic detection will
        be made). In the unfortunate case in which the original dataset and
        the subset contain the same amount of instances, then this parameter
        is considered to contain the task labels of the subset.
        Alternatively, it can be a single int value, in which case
        that value will be used as the task label for all the instances.
        Defaults to None, which means that the dataset will try to
        obtain the task labels from the original dataset. If no task labels
        could be found, a default task label 0 will be applied to all
        instances.
    :param targets: The target dictionary of each pattern. Defaults to None,
        which means that the targets will be retrieved from the dataset (if
        possible). This can either be a list of target dictionaries for the
        original dataset or the list of target dictionaries for the instances
        of the subset (an automatic detection will be made). In the
        unfortunate case in which the original dataset and the subset contain
        the same amount of instances, then this parameter is considered to
        contain the target dictionaries of the subset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a
        `collate_fn` field exists in the dataset. If no such field exists,
        the default collate function for detection will be used.
    """

    if isinstance(dataset, DetectionDataset):
        if (
            class_mapping is None
            and transform is None
            and target_transform is None
            and transform_groups is None
            and initial_transform_group is None
            and task_labels is None
            and targets is None
            and collate_fn is None
        ):
            return dataset.subset(indices)

    targets_data: Optional[DataAttribute[TTargetType]] = \
        _init_targets(dataset, targets, check_shape=False)
    task_labels_data: Optional[DataAttribute[int]] = \
        _init_task_labels(dataset, task_labels, check_shape=False)

    del task_labels
    del targets

    transform_gs = _init_transform_groups(
        transform_groups,
        transform,
        target_transform,
        initial_transform_group,
        dataset,
    )

    if initial_transform_group is not None and isinstance(
        dataset, AvalancheDataset
    ):
        dataset = dataset.with_transforms(initial_transform_group)

    if class_mapping is not None:
        # Update the static targets by remapping the label of each box.
        if targets_data is None:
            # Should not happen: the following line usually fails.
            targets_data = dataset.targets  # type: ignore

        assert targets_data is not None, \
            'To execute the class mapping, a list of targets is required.'

        tgs = [
            _detection_class_mapping_transform(
                class_mapping, example_target_dict)
            for example_target_dict in targets_data]

        targets_data = DataAttribute(tgs, "targets")

        # Also remap at __getitem__ time through a frozen
        # (non-replaceable) target transform.
        mapping_fn = partial(_detection_class_mapping_transform, class_mapping)
        frozen_transform_groups = DefaultTransformGroups(
            (None, mapping_fn)
        )
    else:
        frozen_transform_groups = None

    das: List[DataAttribute] = []
    if targets_data is not None:
        das.append(targets_data)
    if task_labels_data is not None:
        das.append(task_labels_data)

    if collate_fn is None:
        collate_fn = detection_collate_fn

    return DetectionDataset(
        [dataset],
        indices=list(indices) if indices is not None else None,
        data_attributes=das if len(das) > 0 else None,
        transform_groups=transform_gs,
        frozen_transform_groups=frozen_transform_groups,
        collate_fn=collate_fn,
    )

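
# --- Editorial usage sketch (not part of the original source file) ---
# detection_subset is only needed when the subset must also remap classes or
# override transforms/labels; plain `dataset.subset(indices)` suffices
# otherwise. `det_data` is a hypothetical DetectionDataset with three
# classes (labels 0, 1 and 2).
def _example_detection_subset(det_data: DetectionDataset):
    # Keep the first 100 patterns and swap class labels 0 and 2.
    return detection_subset(
        det_data,
        indices=list(range(100)),
        class_mapping=[2, 1, 0],
    )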

def concat_detection_datasets(
    datasets: Sequence[SupportedDetectionDataset],
    *,
    transform: Optional[XTransform] = None,
    target_transform: Optional[YTransform] = None,
    transform_groups: Optional[Mapping[str,
                                       Tuple[XTransform, YTransform]]] = None,
    initial_transform_group: Optional[str] = None,
    task_labels: Optional[Union[int,
                                Sequence[int],
                                Sequence[Sequence[int]]]] = None,
    targets: Optional[Union[
        Sequence[TTargetType], Sequence[Sequence[TTargetType]]
    ]] = None,
    collate_fn: Optional[Callable[[List], Any]] = None
) -> DetectionDataset:
    """Creates an ``AvalancheConcatDataset`` instance.

    For simple concatenation operations you should use the method
    `dataset.concat(other)` or
    `concat_datasets` from `avalanche.benchmarks.utils.utils`.
    Use this constructor only if you need to redefine transformations or
    class/task labels.

    A Dataset that behaves like a PyTorch
    :class:`torch.utils.data.ConcatDataset`. However, this Dataset also
    supports transformations, slicing, advanced indexing, the targets field
    and all the other goodies listed in :class:`AvalancheDataset`.

    This dataset guarantees that the operations involving the transformations
    and transformation groups are consistent across the concatenated datasets
    (if they are subclasses of :class:`AvalancheDataset`).

    :param datasets: A collection of datasets.
    :param transform: A function/transform that takes the X value of a
        pattern from the original dataset and returns a transformed version.
    :param target_transform: A function/transform that takes in the target
        and transforms it.
    :param transform_groups: A dictionary containing the transform groups.
        Transform groups are used to quickly switch between training and
        eval (test) transformations. This becomes useful when one needs to
        test on the training dataset, as test transformations usually don't
        contain random augmentations. ``AvalancheDataset`` natively supports
        the 'train' and 'eval' groups by calling the ``train()`` and
        ``eval()`` methods. When using custom groups one can use the
        ``with_transforms(group_name)`` method instead. Defaults to None,
        which means that the current transforms will be used to
        handle both 'train' and 'eval' groups (just like in standard
        ``torchvision`` datasets).
    :param initial_transform_group: The name of the initial transform group
        to be used. Defaults to None, which means that if all
        AvalancheDatasets in the input datasets list agree on a common
        group (the "current group" is the same for all datasets), then that
        group will be used as the initial one. If the list of input datasets
        does not contain an AvalancheDataset or if the AvalancheDatasets
        do not agree on a common group, then 'train' will be used.
    :param targets: The label of each pattern. Can either be a sequence of
        labels or, alternatively, a sequence containing sequences of labels
        (one for each dataset to be concatenated). Defaults to None, which
        means that the targets will be retrieved from the datasets (if
        possible).
    :param task_labels: The task labels for each pattern. Must be a sequence
        of ints, one for each pattern in the dataset. Alternatively, task
        labels can be expressed as a sequence containing sequences of ints
        (one for each dataset to be concatenated) or even a single int,
        in which case that value will be used as the task label for all
        instances. Defaults to None, which means that the dataset will try
        to obtain the task labels from the original datasets. If no task
        labels could be found for a dataset, a default task label 0 will
        be applied to all patterns of that dataset.
    :param collate_fn: The function to use when slicing to merge single
        patterns. This function is the function used in the data loading
        process, too. If None, the constructor will check if a `collate_fn`
        field exists in the first dataset. If no such field exists, the
        default collate function for detection will be used.
        Beware that the chosen collate function will be applied to all
        the concatenated datasets, even if a different collate is defined
        in some of them.
    """
    dds = []
    per_dataset_task_labels = _split_user_def_task_label(
        datasets, task_labels
    )

    per_dataset_targets = _split_user_def_targets(
        datasets,
        targets,
        lambda x: isinstance(x, dict)
    )

    for dd, dataset_task_labels, dataset_targets in \
            zip(datasets, per_dataset_task_labels, per_dataset_targets):
        dd = make_detection_dataset(
            dd,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=dataset_task_labels,
            targets=dataset_targets,
            collate_fn=collate_fn,
        )
        dds.append(dd)

    if (
        transform is None
        and target_transform is None
        and transform_groups is None
        and initial_transform_group is None
        and task_labels is None
        and targets is None
        and collate_fn is None
        and len(datasets) > 0
    ):
        d0 = datasets[0]
        if isinstance(d0, DetectionDataset):
            for d1 in datasets[1:]:
                d0 = d0.concat(d1)
            return d0

    das: List[DataAttribute] = []
    if len(dds) > 0:
        #######################################
        # TRANSFORMATION GROUPS
        #######################################
        transform_groups_obj = _init_transform_groups(
            transform_groups,
            transform,
            target_transform,
            initial_transform_group,
            dds[0],
        )

        # Find a common "current_group" or fall back to "train"
        if initial_transform_group is None:
            uniform_group = None
            for d_set in datasets:
                if isinstance(d_set, AvalancheDataset):
                    if uniform_group is None:
                        uniform_group = d_set._transform_groups.current_group
                    elif (
                        uniform_group
                        != d_set._transform_groups.current_group
                    ):
                        uniform_group = None
                        break

            if uniform_group is None:
                initial_transform_group = "train"
            else:
                initial_transform_group = uniform_group

        #######################################
        # DATA ATTRIBUTES
        #######################################

        totlen = sum(len(d) for d in datasets)
        if task_labels is not None:
            # User-defined task labels always take precedence
            all_labels: IDataset[int]
            if isinstance(task_labels, int):
                all_labels = ConstantSequence(task_labels, totlen)
            else:
                all_labels_lst = []
                for dd, dataset_task_labels in \
                        zip(dds, per_dataset_task_labels):
                    assert dataset_task_labels is not None

                    # We already checked that len(t_labels) == len(dataset)
                    # (done in _split_user_def_task_label)
                    if isinstance(dataset_task_labels, int):
                        all_labels_lst.extend([dataset_task_labels] * len(dd))
                    else:
                        all_labels_lst.extend(dataset_task_labels)
                all_labels = all_labels_lst
            das.append(
                DataAttribute(
                    all_labels, "targets_task_labels", use_in_getitem=True
                )
            )

        if targets is not None:
            # User-defined targets always take precedence
            all_targets_lst: List[TTargetType] = []
            for dd, dataset_targets in zip(dds, per_dataset_targets):
                assert dataset_targets is not None

                # We already checked that len(targets) == len(dataset)
                # (done in _split_user_def_targets)
                all_targets_lst.extend(dataset_targets)
            das.append(DataAttribute(all_targets_lst, "targets"))
    else:
        transform_groups_obj = None
        initial_transform_group = 'train'

    data = DetectionDataset(
        dds,
        transform_groups=transform_groups_obj,
        data_attributes=das if len(das) > 0 else None
    )
    return data.with_transforms(initial_transform_group)

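
# --- Editorial usage sketch (not part of the original source file) ---
# Concatenates two hypothetical experience datasets, assigning a distinct
# task label to every pattern of each dataset (the nested-sequence form
# accepted by `task_labels`, one inner sequence per dataset).
def _example_concat(exp0_data, exp1_data):
    return concat_detection_datasets(
        [exp0_data, exp1_data],
        task_labels=[
            [0] * len(exp0_data),
            [1] * len(exp1_data),
        ],
    )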

def _select_targets(
        dataset: SupportedDetectionDataset,
        indices: Optional[List[int]]) -> Sequence[TTargetType]:
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets
    else:
        raise ValueError(
            "Unsupported dataset: must have a valid targets field"
        )

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets


__all__ = [
    "SupportedDetectionDataset",
    "DetectionDataset",
    "make_detection_dataset",
    "detection_subset",
    "concat_detection_datasets",
]