yeliudev / nncore · build 6185555099 (push · github)

14 Sep 2023 12:38PM UTC · coverage: 16.593% (+0.1%) from 16.475%
Commit by yeliudev: Upgrade dist functions

10 of 10 new or added lines in 4 files covered (100.0%)
674 of 4062 relevant lines covered (16.59%)
3.21 hits per line

Source File: /nncore/engine/engine.py · file coverage: 0.0%

# Copyright (c) Ye Liu. Licensed under the MIT License.

from collections import OrderedDict

import torch

import nncore
from nncore.nn import build_model
from nncore.optim import build_optimizer
from nncore.utils import CfgNode
from .buffer import Buffer
from .builder import build_dataloader, build_hook
from .comm import gather, is_distributed, is_main_process, sync
from .hooks import Hook
from .utils import get_checkpoint, load_checkpoint

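# Default training schedule: a single 5-epoch stage of SGD with momentum and
# weight decay, iteration-based cosine learning rate decay, a 500-iteration
# linear warm-up, and validation after every epoch.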
_DEFAULT_STAGES = [
    dict(
        epochs=5,
        optimizer=dict(type='SGD', lr=1e-2, momentum=0.9, weight_decay=1e-4),
        lr_schedule=dict(type='iter', policy='cosine'),
        warmup=dict(type='iter', policy='linear', steps=500, ratio=0.001),
        validation=dict(interval=1))
]

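# Hooks registered on every engine, in execution order. Hooks passed by the
# user are appended after these unless a ``before`` anchor is specified in
# :meth:`register_hook`.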
_DEFAULT_HOOKS = [
    'TimerHook', 'LrUpdaterHook', 'OptimizerHook', 'CheckpointHook',
    'EvalHook', 'EventWriterHook'
]


@nncore.bind_getter('mode', 'max_stages', 'max_epochs', 'max_iters',
                    'start_iter', 'stage', 'epoch', 'iter', 'kwargs')
class Engine(object):
    """
    An engine that can take over the whole training, validation, and testing
    process, with all the babysitting work (stage control, optimizer
    configuration, lr scheduling, checkpoint management, metrics & tensorboard
    writing, etc.) done automatically.

    Args:
        model (:obj:`nn.Module` | cfg | str): The model or config of the model.
            The :obj:`forward` method of the model should return a dict
            containing a ``_avg_factor`` field indicating the number of
            samples in the current batch, and optionally a ``_out`` field
            denoting the model outputs to be collected and evaluated.
        data_loaders (dict | str): The configs of data loaders for training,
            validation, and testing. The dict should be in the format of
            ``dict(train=train_loader, val=val_loader, test=test_loader)``.
        stages (list[dict] | dict | None, optional): The stage config or list
            of stage configs to be scheduled. Each stage config should be a
            dict containing the following fields:

            - `epochs` (int): Number of epochs in the stage.
            - `optimizer` (:obj:`optim.Optimizer` | dict): The optimizer or \
                an optimizer config containing the following fields:

                - `type` (str): Type of the optimizer, which can be accessed \
                    via :obj:`torch.optim` attributes, e.g. ``'SGD'``.
                - `configs for the optimizer, e.g.` ``lr=0.01, momentum=0.9``.

            - `lr_schedule` (dict, optional): The learning rate schedule \
                config containing the following fields:

                - `type` (str): Type of the learning rate schedule. Expected \
                    values include ``'epoch'`` and ``'iter'``, indicating \
                    updating learning rates every epoch or iteration.
                - `policy` (str): The learning rate policy to use. Currently \
                    supported policies include ``step``, ``cosine``, ``exp``, \
                    ``poly``, and ``inv``.
                - `configs for the learning rate policy, e.g.` \
                    ``target_lr=0``. Please refer to :obj:`LrUpdaterHook` for \
                    full configs.

            - `warmup` (dict, optional): The warm-up policy config containing \
                the following fields:

                - `type` (str): Type of the warm-up schedule. Expected values \
                    include ``'epoch'`` and ``'iter'``, indicating warming up \
                    for ``steps`` epochs or iterations.
                - `policy` (str): The warm-up policy to use. Currently \
                    supported policies include ``linear``, ``exp`` and \
                    ``constant``.
                - `steps` (int): Number of epochs or iterations to warm up.
                - `ratio` (float): The ratio of the learning rate to start \
                    with. Expected values are in the range of ``0 ~ 1``.

            - `validation` (dict, optional): The validation config containing \
                the following fields:

                - `interval` (int, optional): The interval of performing \
                    validation. ``0`` means not performing validation. \
                    Default: ``0``.
                - `offset` (int, optional): The number of epochs to skip \
                    before counting the interval. Default: ``0``.

            Default: ``None``.
        hooks (list[:obj:`Hook` | dict | str] | None, optional): The list of
            extra hooks to be registered. Each hook can be represented as a
            :obj:`Hook`, a dict or a str. Default: ``None``.
        buffer_size (int, optional): Maximum size of the buffer. Default:
            ``100000``.
        logger (:obj:`logging.Logger` | str | None, optional): The logger or
            name of the logger to use. Default: ``None``.
        work_dir (str | None, optional): Path to the working directory. If not
            specified, the default working directory will be used. Default:
            ``None``.
        seed (int | None, optional): The random seed to use in data loaders.
            Default: ``None``.
        meta (any | None, optional): A dictionary-like object containing meta
            data of this engine. Default: ``None``.

    Example:
        >>> # Build model
        >>> model = build_model()
        ...
        >>> # Build data loaders
        >>> train_loader = build_dataloader(split='train')
        >>> val_loader = build_dataloader(split='val')
        >>> data_loaders = dict(train=train_loader, val=val_loader)
        ...
        >>> # Configure stages:
        >>> # [Stage 1] Train the model for 5 epochs using the Adam optimizer
        >>> # with a fixed learning rate (1e-3) and a linear warm-up policy.
        >>> # [Stage 2] Train the model for another 3 epochs using SGD with
        >>> # momentum and an iter-based cosine learning rate schedule.
        >>> # Perform validation after every training epoch.
        >>> stages = [
        ...     dict(
        ...         epochs=5,
        ...         optimizer=dict(type='Adam', lr=1e-3),
        ...         warmup=dict(type='iter', policy='linear', steps=500)),
        ...     dict(
        ...         epochs=3,
        ...         optimizer=dict(type='SGD', lr=1e-3, momentum=0.9),
        ...         lr_schedule=dict(type='iter', policy='cosine'),
        ...         validation=dict(interval=1))
        ... ]
        ...
        >>> # Initialize and launch engine
        >>> engine = Engine(model, data_loaders, stages=stages)
        >>> engine.launch()
    """

    def __init__(self,
                 model,
                 data_loaders,
                 stages=None,
                 hooks=None,
                 buffer_size=100000,
                 logger=None,
                 work_dir=None,
                 seed=None,
                 meta=None,
                 **kwargs):
        self.model = build_model(model, **kwargs)

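        # Normalize the data loader dict: a bare loader is treated as the
        # training loader, and a missing 'val' or 'test' entry falls back to
        # the other one (or to 'train'), so all three splits always exist.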
        if 'train' not in data_loaders:
            data_loaders = dict(train=data_loaders)

        for a, b in (('val', 'test'), ('test', 'val')):
            if a not in data_loaders:
                loader = data_loaders[b if b in data_loaders else 'train']
                if isinstance(loader, dict):
                    data_loaders[a] = loader.copy()
                else:
                    data_loaders[a] = loader

        self.data_loaders = {
            k: build_dataloader(v, seed=seed)
            for k, v in data_loaders.items()
        }

        if isinstance(stages, dict):
            self.stages = [stages]
        else:
            self.stages = stages or _DEFAULT_STAGES

        self.register_hook(_DEFAULT_HOOKS)
        if is_distributed():
            self.register_hook('SamplerSeedHook', before='OptimizerHook')
        if hooks is not None:
            self.register_hook(hooks)

        time_str = nncore.get_timestamp()
        self.work_dir = work_dir or nncore.join('work_dirs', time_str)

        log_file = nncore.join(self.work_dir, time_str + '.log')
        self.logger = nncore.get_logger(logger, log_file=log_file)

        self.buffer = Buffer(max_size=buffer_size, logger=self.logger)
        self.reset_states()

        self.meta = meta
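    # The properties below locate the engine within the training schedule:
    # the active stage config and the epoch/iteration counters relative to
    # the current stage and epoch.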
    @property
    def cur_stage(self):
        return self.stages[self._stage]

    @property
    def epoch_in_stage(self):
        cumsum = 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                return self._epoch - cumsum
            cumsum += stage['epochs']
        return self.stages[-1]['epochs']

    @property
    def iter_in_stage(self):
        cumsum = 0
        for i in range(self._stage):
            cumsum += len(
                self.data_loaders['train']) * self.stages[i]['epochs']
        return self._iter - cumsum

    @property
    def iter_in_epoch(self):
        return self._iter - len(self.data_loaders['train']) * self._epoch

    def _call_hook(self, name):
        for hook in self.hooks.values():
            getattr(hook, name)(self)

    def reset_states(self):
        self.buffer.clear()
        self._max_stages = 0 if self.stages is None else len(self.stages)
        self._max_epochs = 0 if self.stages is None else sum(
            stage['epochs'] for stage in self.stages)
        self._max_iters = (len(self.data_loaders['train']) if 'train'
                           in self.data_loaders else 0) * self._max_epochs
        self._start_iter = self._stage = self._epoch = self._iter = 0

    def register_hook(self, hook, before=None, overwrite=True, **kwargs):
        """
        Register a hook or a list of hooks into the engine.

        Args:
            hook (list | :obj:`Hook` | dict | str): The hook or list of hooks
                to be registered. Each hook can be represented as a
                :obj:`Hook`, a dict or a str.
            before (str, optional): Name of the hook to be inserted before. If
                not specified, the new hook will be added to the end of the
                hook list. Default: ``None``.
            overwrite (bool, optional): Whether to overwrite the old hook with
                the same name if it exists. Default: ``True``.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.register_hook(
                    h, before=before, overwrite=overwrite, **kwargs)
            return
        elif isinstance(hook, (dict, str)):
            hook = build_hook(hook, **kwargs)
        elif not isinstance(hook, Hook):
            raise TypeError(
                "hook must be a Hook, a dict or a str, but got '{}'".format(
                    type(hook)))

        if not hasattr(self, 'hooks'):
            self.hooks = OrderedDict()

        if hook.name in self.hooks:
            if overwrite:
                keys = list(self.hooks.keys())
                if before is None and keys[-1] != hook.name:
                    before = keys[keys.index(hook.name) + 1]
                self.hooks.pop(hook.name)
            else:
                raise KeyError("hook '{}' exists".format(hook.name))

        self.hooks[hook.name] = hook

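        # The new hook was appended at the end of the OrderedDict. To place
        # it before ``before``, rotate every key from ``before`` up to (but
        # not including) the new hook to the end of the dict.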
        if before is not None:
            if before not in self.hooks:
                raise ValueError("hook '{}' not found".format(before))

            keys = list(self.hooks.keys())
            for key in keys[keys.index(before):-1]:
                self.hooks.move_to_end(key)

    def unregister_hook(self, hook):
        """
        Unregister a hook or a list of hooks from the engine.

        Args:
            hook (list | :obj:`Hook` | str): The hook or list of hooks to be
                unregistered. Each hook can be represented as a :obj:`Hook` or
                a str.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.unregister_hook(h)
            return

        if isinstance(hook, Hook):
            hook = hook.name
        self.hooks.pop(hook)

    def load_checkpoint(self, checkpoint, **kwargs):
        """
        Load a checkpoint from a file or a URL.

        Args:
            checkpoint (dict | str): A dict, a filename, a URL or a
                ``torchvision://<model_name>`` str indicating the checkpoint.
        """
        load_checkpoint(
            self.model,
            checkpoint,
            map_location=next(self.model.parameters()).device,
            logger=self.logger,
            **kwargs)

        if isinstance(checkpoint, str):
            self.logger.info('Loaded checkpoint from {}'.format(checkpoint))
        else:
            self.logger.info('Loaded checkpoint')

    def resume(self, checkpoint, **kwargs):
        """
        Resume from a checkpoint file.

        Args:
            checkpoint (dict | str): A dict, a filename or a URL indicating
                the checkpoint.
        """
        if isinstance(checkpoint, str):
            checkpoint = get_checkpoint(
                checkpoint, map_location=next(self.model.parameters()).device)

        if self.stages != checkpoint['meta']['stages']:
            self.logger.warn(
                'Stages in the engine and checkpoint are mismatched:'
                '\n\nCurrent stages: {}\n\nCheckpoint stages: {}'.format([
                    c.to_dict() if isinstance(c, CfgNode) else c
                    for c in self.stages
                ], checkpoint['meta']['stages']))

        load_checkpoint(self.model, checkpoint, logger=self.logger, **kwargs)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = self._start_iter = checkpoint['meta']['iter']

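        # Recover the stage index from the restored epoch counter by walking
        # through the stage configs until the cumulative epoch count covers
        # the current epoch.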
        cumsum, count = 0, 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                break
            cumsum += stage['epochs']
            count += 1
        self._stage = count

        if 'optimizer' in checkpoint:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'], params=self.model.parameters())
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            raise KeyError('optimizer not found in the checkpoint')

        self.logger.info('Resumed stage {}, epoch {}, iter {}'.format(
            self._stage + 1, self._epoch, self._iter))

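    # One training step: run the forward pass, sum all '*loss*' outputs into
    # a total loss when the model does not provide one, and push every output
    # into the buffer for the hooks to consume.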
    def train_iter(self, data):
        self._call_hook('before_train_iter')

        output = self.model(data, mode=self._mode, **self._kwargs)

        self.losses = {k: v for k, v in output.items() if 'loss' in k}
        if 'loss' not in output:
            self.losses['loss'] = output['loss'] = sum(
                v for v in self.losses.values())

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_train_iter')
        self._iter += 1

    def val_iter(self, data):
        self._call_hook('before_val_iter')

        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        if any('loss' in key for key in output) and 'loss' not in output:
            output['loss'] = sum(v for k, v in output.items() if 'loss' in k)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_val_iter')

    def test_iter(self, data):
        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

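    # Epoch-level loops: each one sets the engine mode, toggles the model
    # between train/eval, binds the matching data loader, and runs the
    # corresponding *_iter step over it.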
    def train_epoch(self):
        self._mode = 'train'
        self.model.train()
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_train_epoch')

        for data in self.data_loader:
            self.train_iter(data)

        self._call_hook('after_train_epoch')
        self._epoch += 1

    def val_epoch(self):
        self.logger.info('Validating...')
        self._mode = 'val'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_val_epoch')

        for data in nncore.ProgressBar(self.data_loader):
            self.val_iter(data)

        self._call_hook('after_val_epoch')

    def test_epoch(self):
        self.logger.info('Evaluating...')
        self._mode = 'test'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        for data in nncore.ProgressBar(self.data_loader):
            self.test_iter(data)

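    # Run a single stage: build a fresh optimizer when entering the stage,
    # then alternate between training epochs and optional validation epochs
    # according to the stage's validation config.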
    def run_stage(self):
        if isinstance(self.cur_stage['optimizer'], dict):
            optim = self.cur_stage['optimizer'].copy()
            optim_type = optim.pop('type')
            optim_args = ['{}: {}'.format(k, v) for k, v in optim.items()]
            optim = '{}({})'.format(optim_type, ', '.join(optim_args))
        else:
            optim = '{}()'.format(
                self.cur_stage['optimizer'].__class__.__name__)

        self.logger.info('Stage: {}, epochs: {}, optimizer: {}'.format(
            self._stage + 1, self.cur_stage['epochs'], optim))

        if self.epoch_in_stage == 0:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'], params=self.model.parameters())

        self._call_hook('before_stage')

        for _ in range(self.cur_stage['epochs'] - self.epoch_in_stage):
            self.train_epoch()
            cfg = self.cur_stage.get('validation')
            if (cfg is not None and 'val' in self.data_loaders
                    and cfg.get('interval', 0) > 0
                    and self.epoch_in_stage > cfg.get('offset', 0)
                    and self.epoch_in_stage % cfg.get('interval', 0) == 0):
                self.val_epoch()

        self._call_hook('after_stage')
        self._stage += 1

    def evaluate(self):
        """
        Perform evaluation. This method is expected to be called after
        validation or testing.
        """
        blob = self.buffer.pop('_out')
        blob = gather(blob)

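        # Only the main process assembles the gathered outputs and runs the
        # dataset's evaluate method; other ranks return an empty dict. All
        # ranks call sync() before returning.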
        if is_main_process():
            blob = nncore.interleave(blob)[:len(self.data_loader.dataset)]

            cfg = self.cur_stage.get('validation')
            if cfg is not None:
                cfg = cfg.copy()
                cfg.pop('interval', None)
                cfg.pop('offset', None)
            else:
                cfg = dict()

            output = self.data_loader.dataset.evaluate(
                blob, logger=self.logger, **cfg)
        else:
            output = dict()

        sync()
        return output

    def launch(self, eval=False, **kwargs):
        """
        Launch the engine.

        Args:
            eval (bool, optional): Whether to run evaluation only. Default:
                ``False``.
        """
        self._kwargs = kwargs

        if eval:
            self.test_epoch()
            output = self.evaluate()
            self.logger.info(
                'Evaluation results: ' +
                ', '.join(['{}: {}'.format(k, v) for k, v in output.items()]))
            return output

        self.logger.info('Launch engine, host: {}, work_dir: {}'.format(
            nncore.get_host_info(), self.work_dir))
        self._call_hook('before_launch')

        while self._stage < self._max_stages:
            self.run_stage()

        self._call_hook('after_launch')