• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5268393053

pending completion
5268393053

Pull #1397

github

web-flow
Merge 60d244754 into e91562200
Pull Request #1397: Specialize benchmark creation helpers

417 of 538 new or added lines in 30 files covered. (77.51%)

43 existing lines in 5 files now uncovered.

16586 of 22630 relevant lines covered (73.29%)

2.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.69
/avalanche/benchmarks/classic/ctrl.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 22-06-2021                                                             #
7
# Author(s): Tom Veniat                                                        #
8
# E-mail: contact@continualai.org                                              #
9
# Website: avalanche.continualai.org                                           #
10
################################################################################
11

12
import random
4✔
13
import sys
4✔
14
from pathlib import Path
4✔
15
from typing import List, Optional, Tuple
4✔
16
from PIL.Image import Image
4✔
17

18
import torchvision.transforms.functional as F
4✔
19
from torchvision import transforms
4✔
20
from tqdm import tqdm
4✔
21
from avalanche.benchmarks.generators.benchmark_generators import (
4✔
22
    dataset_classification_benchmark,
23
)
24

25
from avalanche.benchmarks.utils.classification_dataset import (
4✔
26
    ClassificationDataset,
27
)
28

29
# `ctrl` is an optional dependency: fail at import time with an actionable
# message instead of a bare ImportError.
try:
    import ctrl
except ImportError as err:
    raise ModuleNotFoundError(
        "ctrl not found, if you want to use this "
        "dataset please install avalanche with the "
        "extra dependencies: "
        "pip install avalanche-lib[extra]"
    ) from err
38

39
from avalanche.benchmarks import dataset_benchmark
4✔
40
from avalanche.benchmarks.datasets import default_dataset_location
4✔
41
from avalanche.benchmarks.utils import (
4✔
42
    make_tensor_classification_dataset,
43
    common_paths_root,
44
    make_classification_dataset,
45
    PathsDataset,
46
)
47

48

49
def CTrL(
    stream_name: str,
    save_to_disk: bool = False,
    path: Path = default_dataset_location(""),
    seed: Optional[int] = None,
    n_tasks: Optional[int] = None,
):
    """
    Gives access to the Continual Transfer Learning benchmark streams
    introduced in https://arxiv.org/abs/2012.12631.

    :param stream_name: Name of the test stream to generate. Must be one of
        `s_plus`, `s_minus`, `s_in`, `s_out`, `s_pl` or `s_long`.
    :param save_to_disk: Whether to save each stream on the disk or load
        everything in memory. Setting it to `True` will save memory but takes
        more time on the first generation using the corresponding seed.
    :param path: The path under which the generated stream will be saved if
        save_to_disk is True. Samples are written as PNG files under
        `<path>/ctrl/<stream_name>/seed_<seed>/exp_<i>/<split_name>/`.
    :param seed: The seed to use to generate the streams. If no seed is
        given (`None`), a random one is drawn. `0` is a valid seed and is
        used as given.
    :param n_tasks: The number of tasks to generate. This parameter is only
        relevant for the `s_long` stream, as all other streams have a fixed
        number of tasks. Defaults to 100 for `s_long`.
    :return: A scenario containing 3 streams: train, val and test.
    :raises ValueError: If `n_tasks` is given for a stream other than
        `s_long`.
    """
    # Compare against None explicitly: `seed or ...` would discard seed=0.
    if seed is None:
        seed = random.randint(0, sys.maxsize)

    if stream_name != "s_long" and n_tasks is not None:
        raise ValueError(
            "The n_tasks parameter can only be used with the "
            f'"s_long" stream, asked {n_tasks} for {stream_name}'
        )
    elif stream_name == "s_long" and n_tasks is None:
        n_tasks = 100
    # From here on, n_tasks is not None exactly when stream_name == "s_long".

    stream = ctrl.get_stream(stream_name, seed)

    if save_to_disk:
        # Path(...) also accepts a plain string for `path`.
        folder = Path(path) / "ctrl" / stream_name / f"seed_{seed}"

    # Train, val and test experiences
    exps: List[List[ClassificationDataset]] = [[], [], []]
    for t_id, t in enumerate(
        tqdm(stream, desc=f"Loading {stream_name}"),
    ):
        # Per-task normalization built from the statistics the task exposes.
        trans = transforms.Normalize(t.statistics["mean"], t.statistics["std"])
        for split, split_name, exp in zip(t.datasets, t.split_names, exps):
            samples, labels = split.tensors
            task_labels = [t.id] * samples.size(0)
            if save_to_disk:
                exp_folder = folder / f"exp_{t_id}" / split_name
                exp_folder.mkdir(parents=True, exist_ok=True)
                files: List[Tuple[Path, int]] = []
                for i, (sample, label) in enumerate(zip(samples, labels)):
                    sample_path = exp_folder / f"sample_{i}.png"
                    # Images already written by a previous run are reused.
                    if not sample_path.exists():
                        F.to_pil_image(sample).save(sample_path)
                    files.append((sample_path, label.item()))

                common_root, exp_paths_list = common_paths_root(files)
                paths_dataset: PathsDataset[Image, int] = \
                    PathsDataset(common_root, exp_paths_list)
                dataset: ClassificationDataset = \
                    make_classification_dataset(
                        paths_dataset,
                        task_labels=task_labels,
                        # Images are reloaded from disk, so they must be
                        # converted back to tensors before normalizing.
                        transform=transforms.Compose(
                            [transforms.ToTensor(), trans]
                        ),
                    )
            else:
                dataset = make_tensor_classification_dataset(
                    samples,
                    labels.squeeze(1),
                    task_labels=task_labels,
                    transform=trans,
                    targets=1  # Use the 2nd tensor as targets
                )
            exp.append(dataset)
        # Only "s_long" carries a task budget; stop once it has been reached.
        if n_tasks is not None and t_id == n_tasks - 1:
            break

    return dataset_classification_benchmark(
        train_datasets=exps[0],
        test_datasets=exps[2],
        other_streams_datasets=dict(val=exps[1]),
    )
137

138

139
# Public API of this module: only the CTrL benchmark factory is exported.
__all__ = ["CTrL"]
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc