• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5268393053

pending completion
5268393053

Pull #1397

github

web-flow
Merge 60d244754 into e91562200
Pull Request #1397: Specialize benchmark creation helpers

417 of 538 new or added lines in 30 files covered. (77.51%)

43 existing lines in 5 files now uncovered.

16586 of 22630 relevant lines covered (73.29%)

2.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.75
/avalanche/benchmarks/scenarios/classification_benchmark_creation.py
1
from typing import (
4✔
2
    Any,
3
    Callable,
4
    Dict,
5
    Mapping,
6
    Optional,
7
    Sequence,
8
    Tuple,
9
    TypeVar,
10
)
11
from avalanche.benchmarks.scenarios.dataset_scenario import (
4✔
12
    DatasetScenario,
13
    TStreamsUserDict,
14
)
15
from avalanche.benchmarks.scenarios.generic_benchmark_creation import (
4✔
16
    _make_classification_scenario,
17
    FileAndLabel,
18
    DatasetFactory,
19
    LazyStreamDefinition,
20
    create_generic_benchmark_from_filelists,
21
    create_generic_benchmark_from_paths,
22
    create_generic_benchmark_from_tensor_lists,
23
    create_lazy_generic_benchmark,
24
    create_multi_dataset_generic_benchmark,
25
)
26

27
from avalanche.benchmarks.utils.classification_dataset import (
4✔
28
    SupportedDataset,
29
    make_classification_dataset,
30
)
31
from avalanche.benchmarks.utils.transform_groups import XTransform, YTransform
4✔
32

33

34
# Type variable standing for the concrete scenario type built by the helpers
# in this module: each helper's ``benchmark_factory`` returns it and the helper
# returns the same type. The bound is the string forward reference
# 'DatasetScenario'.
TDatasetScenario = TypeVar('TDatasetScenario', bound='DatasetScenario')
37

38

39
def create_multi_dataset_classification_benchmark(
    train_datasets: Sequence[SupportedDataset],
    test_datasets: Sequence[SupportedDataset],
    *,
    other_streams_datasets: Optional[
        Mapping[str, Sequence[SupportedDataset]]] = None,
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: DatasetFactory = make_classification_dataset,
    benchmark_factory: Callable[
        [TStreamsUserDict, bool], TDatasetScenario
    ] = _make_classification_scenario,  # type: ignore
) -> TDatasetScenario:
    """
    Build a classification benchmark from explicit lists of datasets, where
    every dataset of a stream becomes one experience of that stream.

    The datasets must already be fully populated, task labels included; the
    given transformations, if any, are applied on top of them.

    This function is a classification-oriented front-end for
    :func:`create_multi_dataset_generic_benchmark`. Every argument is
    forwarded to it unchanged. The only difference is the defaults:
    ``dataset_factory`` defaults to ``make_classification_dataset`` and
    ``benchmark_factory`` to ``_make_classification_scenario``. See that
    function for the detailed meaning of each parameter.

    :return: the benchmark built by
        :func:`create_multi_dataset_generic_benchmark`.
    """
    benchmark = create_multi_dataset_generic_benchmark(
        # Stream contents: one dataset per experience.
        train_datasets=train_datasets,
        test_datasets=test_datasets,
        other_streams_datasets=other_streams_datasets,
        # Transformations for the train, eval and additional streams.
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        other_streams_transforms=other_streams_transforms,
        # Remaining options and the factories used to build the result.
        complete_test_set_only=complete_test_set_only,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory,
    )
    return benchmark
83

84

85
def create_lazy_classification_benchmark(
    train_generator: LazyStreamDefinition,
    test_generator: LazyStreamDefinition,
    *,
    other_streams_generators: Optional[Dict[str, LazyStreamDefinition]] = None,
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: DatasetFactory = make_classification_dataset,
    benchmark_factory: Callable[
        [TStreamsUserDict, bool], TDatasetScenario
    ] = _make_classification_scenario,  # type: ignore
) -> TDatasetScenario:
    """
    Build a classification benchmark whose streams are defined lazily,
    through one dataset generator per stream.

    Each generator is expected to produce properly initialized
    :class:`AvalancheDataset` instances; those instances are what the
    experiences are created from.

    All arguments are forwarded unchanged to
    :func:`create_lazy_generic_benchmark`. Only the defaults differ:
    ``dataset_factory`` and ``benchmark_factory`` default to the
    classification-specific implementations. Refer to that function for
    details on each parameter.

    :return: the benchmark built by :func:`create_lazy_generic_benchmark`.
    """
    benchmark = create_lazy_generic_benchmark(
        # Generators describing the experiences of each stream.
        train_generator=train_generator,
        test_generator=test_generator,
        other_streams_generators=other_streams_generators,
        # Transformations for the train, eval and additional streams.
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        other_streams_transforms=other_streams_transforms,
        # Remaining options and the factories used to build the result.
        complete_test_set_only=complete_test_set_only,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory,
    )
    return benchmark
127

128

129
# Plain alias of the generic Caffe-style file-list helper, exported under a
# classification-specific name.
# NOTE(review): unlike the wrapper functions in this module, this alias does
# not pin classification-specific dataset/benchmark factories; it relies on
# whatever defaults create_generic_benchmark_from_filelists has — confirm
# those are appropriate for classification.
create_classification_benchmark_from_filelists = (
    create_generic_benchmark_from_filelists
)
131

132

133
def create_classification_benchmark_from_paths(
    train_lists_of_files: Sequence[Sequence[FileAndLabel]],
    test_lists_of_files: Sequence[Sequence[FileAndLabel]],
    *,
    other_streams_lists_of_files: Optional[
        Dict[str, Sequence[Sequence[FileAndLabel]]]] = None,
    task_labels: Sequence[int],
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: DatasetFactory = make_classification_dataset,
    benchmark_factory: Callable[
        [TStreamsUserDict, bool], TDatasetScenario
    ] = _make_classification_scenario,  # type: ignore
) -> TDatasetScenario:
    """
    Build a classification benchmark from a sequence of file lists. One
    dataset is created per list, and each such dataset becomes a separate
    experience.

    This is close to :func:`create_classification_benchmark_from_filelists`,
    which takes one Caffe-style file list per experience. Here, each
    experience is instead described by a list of tuples. Each tuple holds
    the full path of a pattern and its label, plus an optional third
    element: a bounding box of the region to crop. The bounding box is
    useful to extract only the part of the image that depicts the desired
    element.

    Every argument is forwarded unchanged to
    :func:`create_generic_benchmark_from_paths`. Only the defaults differ:
    ``dataset_factory`` and ``benchmark_factory`` default to the
    classification-specific implementations. See that function for the
    full parameter description.

    :return: the benchmark built by
        :func:`create_generic_benchmark_from_paths`.
    """
    benchmark = create_generic_benchmark_from_paths(
        # File lists (one per experience) and their task labels.
        train_lists_of_files=train_lists_of_files,
        test_lists_of_files=test_lists_of_files,
        other_streams_lists_of_files=other_streams_lists_of_files,
        task_labels=task_labels,
        # Transformations for the train, eval and additional streams.
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        other_streams_transforms=other_streams_transforms,
        # Remaining options and the factories used to build the result.
        complete_test_set_only=complete_test_set_only,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory,
    )
    return benchmark
189

190

191
def create_classification_benchmark_from_tensor_lists(
    train_tensors: Sequence[Sequence[Any]],
    test_tensors: Sequence[Sequence[Any]],
    *,
    other_streams_tensors: Optional[Dict[str, Sequence[Sequence[Any]]]] = None,
    task_labels: Sequence[int],
    complete_test_set_only: bool = False,
    train_transform: XTransform = None,
    train_target_transform: YTransform = None,
    eval_transform: XTransform = None,
    eval_target_transform: YTransform = None,
    other_streams_transforms: Optional[
        Mapping[str, Tuple[XTransform, YTransform]]] = None,
    dataset_factory: DatasetFactory = make_classification_dataset,
    benchmark_factory: Callable[
        [TStreamsUserDict, bool], TDatasetScenario
    ] = _make_classification_scenario,  # type: ignore
) -> TDatasetScenario:
    """
    Build a classification benchmark directly from lists of tensors. A
    separate dataset is created from each tensor tuple (x, y, z, ...), and
    each of those datasets becomes a separate experience.

    This is the lowest-level way of creating a Continual Learning benchmark.
    Prefer the higher-level helpers whenever they fit.

    Experiences are described through ``train_tensors``, ``test_tensors``
    and, optionally, ``other_streams_tensors``. Each of these is a list with
    one entry per experience, and each entry is itself a list of tensors.
    Every tensor holds the values of one feature ("x", "y", "z", ...) for
    all the patterns of that experience. By default, the second tensor of
    each experience supplies the ``targets`` (the label of each pattern).

    All arguments are forwarded unchanged to
    :func:`create_generic_benchmark_from_tensor_lists`. Only the defaults
    differ: ``dataset_factory`` and ``benchmark_factory`` default to the
    classification-specific implementations. See that function for
    additional details.

    :return: the benchmark built by
        :func:`create_generic_benchmark_from_tensor_lists`.
    """
    benchmark = create_generic_benchmark_from_tensor_lists(
        # Per-experience tensor lists and their task labels.
        train_tensors=train_tensors,
        test_tensors=test_tensors,
        other_streams_tensors=other_streams_tensors,
        task_labels=task_labels,
        # Transformations for the train, eval and additional streams.
        train_transform=train_transform,
        train_target_transform=train_target_transform,
        eval_transform=eval_transform,
        eval_target_transform=eval_target_transform,
        other_streams_transforms=other_streams_transforms,
        # Remaining options and the factories used to build the result.
        complete_test_set_only=complete_test_set_only,
        dataset_factory=dataset_factory,
        benchmark_factory=benchmark_factory,
    )
    return benchmark
246

247

248
# Public API: the classification-specific benchmark creation helpers defined
# (or aliased) in this module.
__all__ = [
    'create_multi_dataset_classification_benchmark',
    'create_lazy_classification_benchmark',
    'create_classification_benchmark_from_filelists',
    'create_classification_benchmark_from_paths',
    'create_classification_benchmark_from_tensor_lists',
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc