• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5399886876

pending completion
5399886876

Pull #1398

github

web-flow
Merge 2c8aba8e6 into a61ae5cab
Pull Request #1398: switch to black formatting

1023 of 1372 new or added lines in 177 files covered. (74.56%)

144 existing lines in 66 files now uncovered.

16366 of 22540 relevant lines covered (72.61%)

2.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

41.67
/avalanche/benchmarks/classic/comniglot.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 13-02-2021                                                             #
7
# Author(s): Jary Pomponi, Antonio Carta                                       #
8
################################################################################
9
from pathlib import Path
4✔
10
from typing import Optional, Sequence, Any, Union
4✔
11
from torch import Tensor
4✔
12
from torchvision.transforms import (
4✔
13
    ToTensor,
14
    Compose,
15
    Normalize,
16
    ToPILImage,
17
)
18
from PIL.Image import Image
4✔
19

20
from avalanche.benchmarks import nc_benchmark
4✔
21
from avalanche.benchmarks.classic.classic_benchmarks_utils import (
4✔
22
    check_vision_benchmark,
23
)
24
from avalanche.benchmarks.datasets import default_dataset_location
4✔
25
from avalanche.benchmarks.datasets.omniglot import Omniglot
4✔
26

27

28
# Default preprocessing applied when the caller does not pass a transform:
# convert the sample to a tensor, then normalize its single channel with the
# mean/std constants below. Train and eval use two separate (but identical)
# Compose objects.
_default_omniglot_train_transform = Compose(
    [
        ToTensor(),
        Normalize(mean=(0.9221,), std=(0.2681,)),
    ]
)

_default_omniglot_eval_transform = Compose(
    [
        ToTensor(),
        Normalize(mean=(0.9221,), std=(0.2681,)),
    ]
)
35

36

37
class PixelsPermutation(object):
    """Apply a fixed permutation to the pixels of the given image.

    Works with both Tensors and PIL images. Returns an object of the same type
    as the input element.

    The permutation is applied to the flattened image (every element of the
    tensor representation), and the result is reshaped back to the original
    shape.
    """

    def __init__(self, index_permutation: Sequence[int]):
        """
        :param index_permutation: Indices used to reorder the flattened image.
            Its length must equal the number of elements of the image tensor,
            otherwise the final reshape back to the original shape fails.
        """
        self.permutation = index_permutation
        self._to_tensor = ToTensor()
        self._to_image = ToPILImage()

    def __call__(self, img: Union[Image, Tensor]) -> Union[Image, Tensor]:
        """Permute the pixels of ``img``.

        :param img: A PIL image or a Tensor.
        :return: The permuted image, of the same type as ``img``.
        :raises ValueError: If ``img`` is neither a PIL image nor a Tensor.
        """
        is_image = isinstance(img, Image)
        if not is_image and not isinstance(img, Tensor):
            raise ValueError("Invalid input: must be a PIL image or a Tensor")

        image_as_tensor: Tensor = self._to_tensor(img) if is_image else img
        original_shape = image_as_tensor.shape

        # Use reshape instead of view: view(-1) raises on non-contiguous
        # tensors (e.g. the result of permute/transpose), while reshape only
        # copies when it has to. Advanced indexing then yields a new tensor.
        permuted = image_as_tensor.reshape(-1)[self.permutation].reshape(
            original_shape
        )

        if is_image:
            return self._to_image(permuted)
        return permuted
70

71

72
def SplitAlphabetOmniglot(
    n_experiences: int,
    *,
    return_task_id=False,
    seed: Optional[int] = None,
    fixed_class_order: Optional[Sequence[int]] = None,
    class_ids_from_zero_from_first_exp: bool = False,
    shuffle: bool = True,
    train_transform: Optional[Any] = _default_omniglot_train_transform,
    eval_transform: Optional[Any] = _default_omniglot_eval_transform,
    dataset_root: Optional[Union[str, Path]] = None
):
    """Class-incremental OMNIGLOT with the alphabet used as target.

    If the dataset is not present in the computer, this method will
    automatically download and store it.

    The returned benchmark will return experiences containing all patterns of a
    subset of alphabets (class-incremental scenario).

    The benchmark API is quite simple and is uniform across all benchmark
    generators. It is recommended to check the tutorial of the "benchmark" API,
    which contains usage examples ranging from "basic" to "advanced".

    :param n_experiences: The number of incremental experiences in the current
        benchmark. Classes are assigned to experiences in equal amounts, so
        this value should be a divisor of the number of classes in the
        dataset.
    :param return_task_id: if True, a progressive task id is returned for every
        experience. If False, all experiences will have a task ID of 0.
    :param seed: A valid int used to initialize the random number generator.
        Can be None.
    :param fixed_class_order: A list of class IDs used to define the class
        order. If None, value of ``seed`` will be used to define the class
        order. If non-None, ``seed`` parameter will be ignored.
        Defaults to None.
    :param class_ids_from_zero_from_first_exp: If True, original class IDs
        will be remapped so that they will appear as having an ascending
        order. For instance, if the resulting class order after shuffling
        (or defined by fixed_class_order) is [23, 34, 11, 7, 6, ...] and
        class_ids_from_zero_from_first_exp is True, then all the patterns
        belonging to class 23 will appear as belonging to class "0",
        class "34" will be mapped to "1", class "11" to "2" and so on.
        This is very useful when drawing confusion matrices and when dealing
        with algorithms with dynamic head expansion. Defaults to False.
        Per-experience remapping (``class_ids_from_zero_in_each_exp``) is not
        exposed by this generator and is always disabled.
    :param shuffle: If true, the class order in the incremental experiences is
        randomly shuffled. Default to True.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations).
        If no transformation is passed, the default train transformation
        will be used.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations).
        If no transformation is passed, the default test transformation
        will be used.
    :param dataset_root: The root path of the dataset. Defaults to None, which
        means that the default location for 'omniglot' will be used.

    :returns: A properly initialized :class:`NCScenario` instance.
    """

    # NOTE(review): this loads exactly the same datasets as SplitOmniglot.
    # Whether the targets are alphabets (as documented above) rather than
    # characters depends on _get_omniglot_dataset / Omniglot — confirm.
    omniglot_train, omniglot_test = _get_omniglot_dataset(dataset_root)
    return nc_benchmark(
        train_dataset=omniglot_train,
        test_dataset=omniglot_test,
        n_experiences=n_experiences,
        task_labels=return_task_id,
        seed=seed,
        fixed_class_order=fixed_class_order,
        shuffle=shuffle,
        class_ids_from_zero_in_each_exp=False,
        class_ids_from_zero_from_first_exp=class_ids_from_zero_from_first_exp,
        train_transform=train_transform,
        eval_transform=eval_transform,
    )
151

152

153
def SplitOmniglot(
    n_experiences: int,
    *,
    return_task_id=False,
    seed: Optional[int] = None,
    fixed_class_order: Optional[Sequence[int]] = None,
    shuffle: bool = True,
    class_ids_from_zero_in_each_exp: bool = False,
    class_ids_from_zero_from_first_exp: bool = False,
    train_transform: Optional[Any] = _default_omniglot_train_transform,
    eval_transform: Optional[Any] = _default_omniglot_eval_transform,
    dataset_root: Optional[Union[str, Path]] = None
):
    """
    Creates a CL benchmark using the OMNIGLOT dataset.

    If the dataset is not present in the computer, this method will
    automatically download and store it.

    The returned benchmark will return experiences containing all patterns of a
    subset of classes, which means that each class is only seen "once".
    This is one of the most common scenarios in the Continual Learning
    literature. Common names used in literature to describe this kind of
    scenario are "Class Incremental", "New Classes", etc.

    By default, an equal amount of classes will be assigned to each experience.
    OMNIGLOT consists of 964 classes, which means that the number of
    experiences can be 1, 2, 4, 241, 482, 964.

    This generator doesn't force a choice on the availability of task labels,
    a choice that is left to the user (see the `return_task_id` parameter for
    more info on task labels).

    The benchmark instance returned by this method will have two fields,
    `train_stream` and `test_stream`, which can be iterated to obtain
    training and test :class:`Experience`. Each Experience contains the
    `dataset` and the associated task label.

    The benchmark API is quite simple and is uniform across all benchmark
    generators. It is recommended to check the tutorial of the "benchmark" API,
    which contains usage examples ranging from "basic" to "advanced".

    :param n_experiences: The number of incremental experiences in the current
        benchmark. Since classes are split evenly across experiences, the
        value of this parameter should be a divisor of the number of classes
        (964 for OMNIGLOT, i.e. one of 1, 2, 4, 241, 482, 964).
    :param return_task_id: if True, a progressive task id is returned for every
        experience. If False, all experiences will have a task ID of 0.
    :param seed: A valid int used to initialize the random number generator.
        Can be None.
    :param fixed_class_order: A list of class IDs used to define the class
        order. If None, value of ``seed`` will be used to define the class
        order. If non-None, ``seed`` parameter will be ignored.
        Defaults to None.
    :param shuffle: If true, the class order in the incremental experiences is
        randomly shuffled. Default to True.
    :param class_ids_from_zero_in_each_exp: If True, original class IDs
        will be mapped to range [0, n_classes_in_exp) for each experience.
        Defaults to False. Mutually exclusive with the
        ``class_ids_from_zero_from_first_exp`` parameter.
    :param class_ids_from_zero_from_first_exp: If True, original class IDs
        will be remapped so that they will appear as having an ascending
        order. For instance, if the resulting class order after shuffling
        (or defined by fixed_class_order) is [23, 34, 11, 7, 6, ...] and
        class_ids_from_zero_from_first_exp is True, then all the patterns
        belonging to class 23 will appear as belonging to class "0",
        class "34" will be mapped to "1", class "11" to "2" and so on.
        This is very useful when drawing confusion matrices and when dealing
        with algorithms with dynamic head expansion. Defaults to False.
        Mutually exclusive with the ``class_ids_from_zero_in_each_exp``
        parameter.
    :param train_transform: The transformation to apply to the training data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations).
        If no transformation is passed, the default train transformation
        will be used.
    :param eval_transform: The transformation to apply to the test data,
        e.g. a random crop, a normalization or a concatenation of different
        transformations (see torchvision.transform documentation for a
        comprehensive list of possible transformations).
        If no transformation is passed, the default test transformation
        will be used.
    :param dataset_root: The root path of the dataset. Defaults to None, which
        means that the default location for 'omniglot' will be used.

    :returns: A properly initialized :class:`NCScenario` instance.
    """

    omniglot_train, omniglot_test = _get_omniglot_dataset(dataset_root)
    return nc_benchmark(
        train_dataset=omniglot_train,
        test_dataset=omniglot_test,
        n_experiences=n_experiences,
        task_labels=return_task_id,
        seed=seed,
        fixed_class_order=fixed_class_order,
        shuffle=shuffle,
        class_ids_from_zero_in_each_exp=class_ids_from_zero_in_each_exp,
        class_ids_from_zero_from_first_exp=class_ids_from_zero_from_first_exp,
        train_transform=train_transform,
        eval_transform=eval_transform,
    )
254

255

256
def _get_omniglot_dataset(dataset_root):
    """Build the Omniglot train and test datasets (with ``download=True``).

    :param dataset_root: Directory holding the dataset. When None, the
        default location for 'omniglot' is used instead.
    :returns: A ``(train, test)`` tuple of :class:`Omniglot` datasets.
    """
    root = (
        default_dataset_location("omniglot") if dataset_root is None else dataset_root
    )

    # The tuple display evaluates left to right: train is built before test.
    return (
        Omniglot(root=root, train=True, download=True),
        Omniglot(root=root, train=False, download=True),
    )
264

265

266
# Public names re-exported by ``from ... import *``.
# NOTE(review): SplitAlphabetOmniglot is public but not listed here — confirm
# whether that is intentional.
__all__ = ["SplitOmniglot"]

if __name__ == "__main__":
    # Smoke test: build a 4-experience benchmark with both transforms set to
    # None and run check_vision_benchmark on it.
    print("Split Omniglot")
    check_vision_benchmark(
        SplitOmniglot(4, train_transform=None, eval_transform=None)
    )
    # Equivalent to sys.exit(0), which simply raises SystemExit(0).
    raise SystemExit(0)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc