• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5399886876

pending completion
5399886876

Pull #1398

github

web-flow
Merge 2c8aba8e6 into a61ae5cab
Pull Request #1398: switch to black formatting

1023 of 1372 new or added lines in 177 files covered. (74.56%)

144 existing lines in 66 files now uncovered.

16366 of 22540 relevant lines covered (72.61%)

2.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

49.04
/avalanche/benchmarks/scenarios/detection_scenario.py
1
################################################################################
2
# Copyright (c) 2022 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 10-03-2022                                                             #
7
# Author(s): Lorenzo Pellegrini                                                #
8
# E-mail: contact@continualai.org                                              #
9
# Website: avalanche.continualai.org                                           #
10
################################################################################
11

12
from typing import (
4✔
13
    Generic,
14
    Iterable,
15
    Mapping,
16
    Optional,
17
    Sequence,
18
    Set,
19
    Tuple,
20
    TypeVar,
21
    List,
22
    Callable,
23
    Union,
24
    overload,
25
)
26
import warnings
4✔
27

28

29
from avalanche.benchmarks.scenarios.dataset_scenario import (
4✔
30
    ClassesTimelineCLScenario,
31
    DatasetScenario,
32
    FactoryBasedStream,
33
    TStreamsUserDict,
34
)
35
from avalanche.benchmarks.scenarios.generic_scenario import (
4✔
36
    AbstractClassTimelineExperience,
37
    CLScenario,
38
    CLStream,
39
)
40
from avalanche.benchmarks.utils.data import AvalancheDataset
4✔
41
from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing
4✔
42
from avalanche.benchmarks.utils.detection_dataset import DetectionDataset
4✔
43

44
# Type variables used to parametrize the generic classes of this module.
# Every bound is given as a string (forward reference) because some of the
# bound classes are only defined further down in this file.

# Datasets: bound to AvalancheDataset (imported from
# avalanche.benchmarks.utils.data).
TCLDataset = TypeVar("TCLDataset", covariant=True, bound="AvalancheDataset")

# Scenarios: the generic DatasetScenario (from dataset_scenario) and the
# DetectionScenario defined in this module.
TDatasetScenario = TypeVar("TDatasetScenario", bound="DatasetScenario")
TDetectionScenario = TypeVar("TDetectionScenario", bound="DetectionScenario")

# Streams: DetectionStream is defined in this module.
TDetectionStream = TypeVar("TDetectionStream", bound="DetectionStream")

# Experiences: DetectionExperience is defined in this module.
TDetectionExperience = TypeVar(
    "TDetectionExperience", bound="DetectionExperience"
)
4✔
60

61

62
def _default_detection_stream_factory(stream_name: str, benchmark: "DetectionScenario"):
    """
    Default ``stream_factory`` used by :class:`DetectionScenario`.

    :param stream_name: The name to give to the new stream.
    :param benchmark: The benchmark the stream belongs to.
    :return: A new :class:`DetectionStream` for that benchmark.
    """
    stream = DetectionStream(benchmark=benchmark, name=stream_name)
    return stream
×
64

65

66
def _default_detection_experience_factory(
    stream: "DetectionStream", experience_idx: int
):
    """
    Default ``experience_factory`` used by :class:`DetectionScenario`.

    :param stream: The stream the experience is taken from.
    :param experience_idx: The index of the experience in ``stream``.
    :return: A new :class:`DetectionExperience`.
    """
    experience = DetectionExperience(
        current_experience=experience_idx, origin_stream=stream
    )
    return experience
×
70

71

72
class DetectionScenario(
    ClassesTimelineCLScenario[TDetectionStream, TDetectionExperience, DetectionDataset]
):
    """
    Base implementation of a Continual Learning object detection benchmark.

    Stream and experience construction is delegated to the base class; see
    :class:`DatasetScenario` for the general semantics.
    """

    def __init__(
        self: TDetectionScenario,
        stream_definitions: TStreamsUserDict,
        n_classes: Optional[int] = None,
        stream_factory: Callable[
            [str, TDetectionScenario], TDetectionStream
        ] = _default_detection_stream_factory,
        experience_factory: Callable[
            [TDetectionStream, int], TDetectionExperience
        ] = _default_detection_experience_factory,
        complete_test_set_only: bool = False,
    ):
        """
        Creates an instance of a Continual Learning object detection benchmark.

        :param stream_definitions: The definition of each stream. See
            :class:`DatasetScenario` for the expected format.
        :param n_classes: The number of classes in the scenario, or None
            (the default) when unknown. Stored as ``self.n_classes``.
        :param stream_factory: Callable receiving a stream name and this
            benchmark and returning the corresponding stream. By default a
            :class:`DetectionStream` is created.
        :param experience_factory: Callable receiving a stream and an
            experience index and returning that experience. By default a
            :class:`DetectionExperience` is created.
        :param complete_test_set_only: When True, the test stream is made of
            a single experience holding the whole test set, so the test
            stream definition must describe exactly one experience.
        """
        super().__init__(
            stream_definitions=stream_definitions,
            complete_test_set_only=complete_test_set_only,
            experience_factory=experience_factory,
            stream_factory=stream_factory,
        )

        #: The number of classes in the scenario. May be None if unknown.
        self.n_classes: Optional[int] = n_classes

    @property
    def classes_in_experience(self):
        """
        Lazy view of the classes found in each experience, keyed by stream
        name: ``classes_in_experience[stream_name][exp_id]``.

        Indexing directly with an experience ID is deprecated; it refers to
        the "train" stream and emits a warning. A new view object is created
        on every access.
        """
        view = _LazyStreamClassesInDetectionExps(self)
        return view


# Alias of DetectionScenario (presumably kept for backward compatibility).
DetectionCLScenario = DetectionScenario
4✔
133

134

135
class DetectionStream(FactoryBasedStream[TDetectionExperience]):
    """
    Stream of a :class:`DetectionScenario`.

    All the stream machinery comes from :class:`FactoryBasedStream`; this
    subclass only stores the owning benchmark with a detection-specific
    type annotation.
    """

    def __init__(
        self,
        name: str,
        benchmark: DetectionScenario,
        *,
        slice_ids: Optional[List[int]] = None,
        set_stream_info: bool = True,
    ):
        """
        :param name: The name of the stream (e.g. "train").
        :param benchmark: The scenario this stream belongs to.
        :param slice_ids: Optional list of experience indices this stream is
            restricted to; forwarded to :class:`FactoryBasedStream`.
        :param set_stream_info: Forwarded to :class:`FactoryBasedStream`.
        """
        # The typed attribute is stored before the base initializer runs.
        self.benchmark: DetectionScenario = benchmark
        super().__init__(
            benchmark=benchmark,
            name=name,
            set_stream_info=set_stream_info,
            slice_ids=slice_ids,
        )
151

152

153
class DetectionExperience(AbstractClassTimelineExperience[DetectionDataset]):
    """
    A single learning experience of a :class:`DetectionScenario`.

    The experience dataset is read from the scenario's stream definitions,
    and the class timeline (classes in this experience, previous classes,
    classes seen so far, future classes) is obtained from the scenario's
    ``get_classes_timeline``. Instances are usually obtained from an object
    detection benchmark stream.
    """

    def __init__(
        self: TDetectionExperience,
        origin_stream: DetectionStream[TDetectionExperience],
        current_experience: int,
    ):
        """
        Creates the experience with index ``current_experience`` of the
        given stream.

        :param origin_stream: The stream this experience belongs to.
        :param current_experience: The index (ID) of this experience within
            ``origin_stream``.
        """
        bench: DetectionScenario = origin_stream.benchmark
        self._benchmark: DetectionScenario = bench
        stream_name = origin_stream.name

        stream_def = bench.stream_definitions[stream_name]
        dataset: DetectionDataset = stream_def.exps_data[current_experience]

        timeline = bench.get_classes_timeline(current_experience, stream=stream_name)
        (
            classes_in_this_exp,
            previous_classes,
            classes_seen_so_far,
            future_classes,
        ) = timeline

        super().__init__(
            origin_stream,
            dataset,
            current_experience,
            classes_in_this_exp,
            previous_classes,
            classes_seen_so_far,
            future_classes,
        )

    @property  # type: ignore[override]
    def benchmark(self) -> DetectionScenario:
        """
        The scenario this experience belongs to.

        The stored value is validated through ``_check_unset_attribute``
        before being returned (presumably this raises if it was never set).
        """
        current = self._benchmark
        DetectionExperience._check_unset_attribute("benchmark", current)
        return current

    @benchmark.setter
    def benchmark(self, bench: DetectionScenario):
        self._benchmark = bench

    def _get_stream_def(self):
        """Return the stream definition of the stream of this experience."""
        definitions = self._benchmark.stream_definitions
        return definitions[self.origin_stream.name]

    @property
    def task_labels(self) -> List[int]:
        """The task labels of this experience, as a new list."""
        labels = self._get_stream_def().exps_task_labels[self.current_experience]
        return list(labels)


# Alias of DetectionExperience (presumably kept for backward compatibility).
GenericDetectionExperience = DetectionExperience
4✔
222

223

224
class _LazyStreamClassesInDetectionExps(Mapping[str, Sequence[Optional[Set[int]]]]):
4✔
225
    def __init__(self, benchmark: DetectionScenario):
4✔
226
        self._benchmark = benchmark
×
NEW
227
        self._default_lcie = _LazyClassesInDetectionExps(benchmark, stream="train")
×
228

229
    def __len__(self):
4✔
230
        return len(self._benchmark.stream_definitions)
×
231

232
    def __getitem__(self, stream_name_or_exp_id):
4✔
233
        if isinstance(stream_name_or_exp_id, str):
×
234
            return _LazyClassesInDetectionExps(
×
235
                self._benchmark, stream=stream_name_or_exp_id
236
            )
237

238
        warnings.warn(
×
239
            "Using classes_in_experience[exp_id] is deprecated. "
240
            "Consider using classes_in_experience[stream_name][exp_id]"
241
            "instead.",
242
            stacklevel=2,
243
        )
244
        return self._default_lcie[stream_name_or_exp_id]
×
245

246
    def __iter__(self):
4✔
247
        yield from self._benchmark.stream_definitions.keys()
×
248

249

250
# Return type of _LazyClassesInDetectionExps.__getitem__: a single set (or
# None) for an integer index, a tuple of sets for a slice.
LazyClassesInExpsRet = Union[Tuple[Optional[Set[int]], ...], Optional[Set[int]]]


class _LazyClassesInDetectionExps(Sequence[Optional[Set[int]]]):
    """
    Lazy, read-only sequence of the classes found in each experience of one
    stream of a :class:`DetectionScenario`.

    Element ``i`` is the set of integer labels listed under the ``"labels"``
    key of the targets of experience ``i``. It is None when those targets
    are not available. Values are computed only when accessed.
    """

    def __init__(self, benchmark: DetectionScenario, stream: str = "train"):
        self._benchmark = benchmark
        self._stream = stream

    def __len__(self):
        return len(self._benchmark.streams[self._stream])

    @overload
    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
        ...

    @overload
    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
        ...

    def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
        # Index/slice handling is delegated to manage_advanced_indexing. It
        # presumably calls _get_single_exp_classes once per selected index
        # and merges slice results with _slice_collate.
        return manage_advanced_indexing(
            exp_id,
            self._get_single_exp_classes,
            len(self),
            _LazyClassesInDetectionExps._slice_collate,
        )

    def __str__(self):
        joined = ", ".join(str(self[i]) for i in range(len(self)))
        return f"[{joined}]"

    def _get_single_exp_classes(self, exp_id) -> Optional[Set[int]]:
        """
        Return the set of labels in experience ``exp_id``, or None when its
        targets are not available.

        :raises IndexError: If the stream is not lazy and ``exp_id`` is not
            in the stream's ``targets_field_sequence``.
        """
        stream_def = self._benchmark.stream_definitions[self._stream]
        targets_seq = stream_def.exps_data.targets_field_sequence
        if not stream_def.is_lazy and exp_id not in targets_seq:
            raise IndexError

        targets = targets_seq[exp_id]
        if targets is None:
            return None

        # Each detection target carries its class labels under "labels".
        return {int(label) for target in targets for label in target["labels"]}

    @staticmethod
    def _slice_collate(
        classes_in_exps: Iterable[Optional[Iterable[int]]],
    ) -> Optional[Tuple[Set[int], ...]]:
        """
        Merge per-experience results into a tuple of sets.

        Returns None as soon as any experience has unknown classes (None).
        """
        collected: List[Set[int]] = []
        for classes in classes_in_exps:
            if classes is not None:
                collected.append(set(classes))
                continue
            return None
        return tuple(collected)
×
304

305

306
# Public API of this module. The _Lazy* helpers, the default factories and
# the type variables are not exported.
__all__ = [
    # Scenario and its alias
    "DetectionScenario",
    "DetectionCLScenario",
    # Stream
    "DetectionStream",
    # Experience alias and experience
    "GenericDetectionExperience",
    "DetectionExperience",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc