• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

yeliudev / nncore / 8679058429

14 Apr 2024 08:17AM UTC coverage: 15.764% (-0.007%) from 15.771%
8679058429

push

github

yeliudev
Update engine logger

0 of 4 new or added lines in 2 files covered. (0.0%)

2 existing lines in 2 files now uncovered.

678 of 4301 relevant lines covered (15.76%)

3.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/nncore/engine/engine.py
1
# Copyright (c) Ye Liu. Licensed under the MIT License.
2

3
from collections import OrderedDict
×
4

5
import torch
×
6

7
import nncore
×
8
from nncore.nn import build_model
×
9
from nncore.optim import build_optimizer
×
10
from nncore.utils import CfgNode
×
11
from .buffer import Buffer
×
12
from .builder import build_dataloader, build_hook
×
13
from .comm import gather, get_world_size, is_distributed, is_main_process, sync
×
14
from .hooks import Hook
×
15
from .utils import get_checkpoint, load_checkpoint
×
16

17
# Stage schedule used when an Engine is built without explicit ``stages``:
# a single 5-epoch stage of momentum SGD with an iteration-based cosine
# learning rate schedule, a linear warm-up, and validation after every epoch.
_DEFAULT_STAGES = [
    {
        'epochs': 5,
        'optimizer': {
            'type': 'SGD',
            'lr': 1e-2,
            'momentum': 0.9,
            'weight_decay': 1e-4,
        },
        'lr_schedule': {'type': 'iter', 'policy': 'cosine'},
        'warmup': {
            'type': 'iter',
            'policy': 'linear',
            'steps': 500,
            'ratio': 0.001,
        },
        'validation': {'interval': 1},
    },
]

# Names of the hooks every Engine registers, in registration order.
_DEFAULT_HOOKS = [
    'TimerHook',
    'LrUpdaterHook',
    'OptimizerHook',
    'CheckpointHook',
    'EvalHook',
    'EventWriterHook',
]
30

31

32
@nncore.bind_getter('mode', 'max_stages', 'max_epochs', 'max_iters',
                    'start_iter', 'stage', 'epoch', 'iter', 'kwargs')
class Engine(object):
    """
    An engine that can take over the whole training, validation, and testing
    process, with all the baby-sitting works (stage control, optimizer
    configuration, lr scheduling, checkpoint management, metrics & tensorboard
    writing, etc.) done automatically.

    Args:
        model (:obj:`nn.Module` | cfg | str): The model or config of the model.
            The :obj:`forward` method of the model should return a dict
            containing a ``_avg_factor`` field indicating the number of
            samples in the current batch, and optionally a ``_out`` field
            denoting the model outputs to be collected and evaluated.
        data_loaders (dict | str): The configs of data loaders for training,
            validation, and testing. The dict should be in the format of
            ``dict(train=train_loader, val=val_loader, test=test_loader)``.
            Anything that is not a dict with a ``train`` key is used as the
            training loader. A missing ``val`` / ``test`` entry falls back to
            the other one, then to ``train``. The given dict is not modified.
        stages (list[dict] | dict | None, optional): The stage config or list
            of stage configs to be scheduled. Each stage config should be a
            dict containing the following fields:

            - `epochs` (int): Number of epochs in the stage.
            - `optimizer` (:obj:`optim.Optimizer` | dict): The optimizer or \
                an optimizer config containing the following fields:

                - `type` (str): Type of the optimizer, which can be accessed \
                    via :obj:`torch.optim` attributes, e.g. ``'SGD'``.
                - `configs for the optimizer, e.g.` ``lr=0.01, momentum=0.9``.

            - `lr_schedule` (dict, optional): The learning rate schedule \
                config containing the following fields:

                - `type` (str): Type of the learning rate schedule. Expected \
                    values include ``'epoch'`` and ``'iter'``, indicating \
                    updating learning rates every epoch or iteration.
                - `policy` (str): The learning rate policy to use. Currently \
                    supported policies include ``step``, ``cosine``, ``exp``, \
                    ``poly``, and ``inv``.
                - `configs for the learning rate policy, e.g.` \
                    ``target_lr=0``. Please refer to :obj:`LrUpdaterHook` for \
                    full configs.

            - `warmup` (dict, optional): The warm-up policy config containing \
                the following fields:

                - `type` (str): Type of the warm-up schedule. Expected values \
                    include ``'epoch'`` and ``'iter'``, indicating warming up \
                    for ``steps`` epochs or iterations.
                - `policy` (str): The warm-up policy to use. Currently \
                    supported policies include ``linear``, ``exp`` and \
                    ``constant``.
                - `steps` (int): Number of epochs or iterations to warm up.
                - `ratio` (float): The ratio of learning rate to start with. \
                    Expected values are in the range of ``0 ~ 1``.

            - `validation` (dict, optional): The validation config containing \
                the following fields:

                - `interval` (int, optional): The interval of performing \
                    validation. ``0`` means not performing validation. \
                    Default: ``0``.
                - `offset` (int, optional): The number of epochs to skip \
                    before counting the interval. Default: ``0``.

            If not specified, a single default stage (5 epochs of SGD with an
            iter-based cosine schedule and linear warm-up) is used.
            Default: ``None``.
        hooks (list[:obj:`Hook` | dict | str] | None, optional): The list of
            extra hooks to be registered. Each hook can be represented as a
            :obj:`Hook`, a dict or a str. Default: ``None``.
        buffer_size (int, optional): Maximum size of the buffer. Default:
            ``100000``.
        logger (:obj:`logging.Logger` | str | None, optional): The logger or
            name of the logger to use. Default: ``None``.
        work_dir (str | None, optional): Path to the working directory. If not
            specified, ``work_dirs/<timestamp>`` will be used. Default:
            ``None``.
        seed (int | None, optional): The random seed to use in data loaders.
            Default: ``None``.
        meta (any | None, optional): A dictionary-like object containing meta
            data of this engine. Default: ``None``.
        amp (dict | str | bool | None, optional): Whether to use automatic
            mixed precision training. A dict is passed to
            :obj:`torch.autocast` (``enabled`` defaults to ``True``; the given
            dict is not modified). A str selects the data type and must be
            ``'fp16'`` / ``'float16'`` or ``'bf16'`` / ``'bfloat16'``. Any
            other value is interpreted as a bool. Default: ``None``.
        debug (bool, optional): Whether to activate debug mode. Default:
            ``False``.
        **kwargs: Extra arguments forwarded to :obj:`build_model`.

    Example:
        >>> # Build model
        >>> model = build_model()
        ...
        >>> # Build data loaders
        >>> train_loader = build_dataloader(split='train')
        >>> val_loader = build_dataloader(split='val')
        >>> data_loaders = dict(train=train_loader, val=val_loader)
        ...
        >>> # Configure stages:
        >>> # [Stage 1] Train the model for 5 epochs using Adam optimizer with
        >>> # a fixed learning rate (1e-3) and a linear warm-up policy.
        >>> # [Stage 2] Train the model for another 3 epochs using SGD with
        >>> # momentum optimizer and an iter-based cosine learning rate
        >>> # schedule. Perform validation after every training epoch.
        >>> stages = [
        ...     dict(
        ...         epochs=5,
        ...         optimizer=dict(type='Adam', lr=1e-3),
        ...         warmup=dict(type='iter', policy='linear', steps=500)),
        ...     dict(
        ...         epochs=3,
        ...         optimizer=dict(type='SGD', lr=1e-3, momentum=0.9),
        ...         lr_schedule=dict(type='iter', policy='cosine'),
        ...         validation=dict(interval=1))
        ... ]
        ...
        >>> # Initialize and launch engine
        >>> engine = Engine(model, data_loaders, stages=stages)
        >>> engine.launch()
    """

    def __init__(self,
                 model,
                 data_loaders,
                 stages=None,
                 hooks=None,
                 buffer_size=100000,
                 logger=None,
                 work_dir=None,
                 seed=None,
                 meta=None,
                 amp=None,
                 debug=False,
                 **kwargs):
        self.model = build_model(model, **kwargs)

        # Only test for a 'train' key on dicts: ``'train' in x`` on a data
        # loader would iterate over it, and on a str it is a substring test.
        if not isinstance(data_loaders, dict) or 'train' not in data_loaders:
            data_loaders = dict(train=data_loaders)
        else:
            # Shallow copy so filling in missing splits below does not mutate
            # the caller's dict.
            data_loaders = dict(data_loaders)

        # A missing 'val' / 'test' loader falls back to the other one, and
        # then to 'train'. Configs are copied so the splits stay independent.
        for split, fallback in (('val', 'test'), ('test', 'val')):
            if split not in data_loaders:
                source = fallback if fallback in data_loaders else 'train'
                loader = data_loaders[source]
                if isinstance(loader, dict):
                    data_loaders[split] = loader.copy()
                else:
                    data_loaders[split] = loader

        self.data_loaders = {
            k: build_dataloader(v, seed=seed)
            for k, v in data_loaders.items()
        }

        if isinstance(stages, dict):
            self.stages = [stages]
        elif stages:
            self.stages = stages
        else:
            # Deep copy so in-place changes to this engine's stages can never
            # leak into the module-level defaults shared by other engines.
            import copy
            self.stages = copy.deepcopy(_DEFAULT_STAGES)

        self.register_hook(_DEFAULT_HOOKS)
        if is_distributed():
            self.register_hook('SamplerSeedHook', before='OptimizerHook')
        if hooks is not None:
            self.register_hook(hooks)

        time_str = nncore.get_timestamp()
        self.work_dir = work_dir or nncore.join('work_dirs', time_str)

        log_file = nncore.join(self.work_dir, time_str + '.log')
        self.logger = nncore.get_logger(logger, log_file=log_file)

        self.buffer = Buffer(max_size=buffer_size, logger=self.logger)
        self.reset_states()

        if isinstance(amp, dict):
            # Copy before filling defaults so the caller's config is untouched.
            amp_cfg = dict(amp)
            amp_cfg.setdefault('enabled', True)
            self.amp_cfg = amp_cfg
        elif isinstance(amp, str):
            if amp in ('fp16', 'float16'):
                dtype = torch.float16
            elif amp in ('bf16', 'bfloat16'):
                dtype = torch.bfloat16
            else:
                raise TypeError(
                    "Amp training only supports 'float16' or 'bfloat16' data "
                    "types, but got '{}'".format(amp))
            self.amp_cfg = dict(enabled=True, dtype=dtype)
        else:
            self.amp_cfg = dict(enabled=bool(amp))

        self.meta = meta
        self.debug = debug
        self._kwargs = dict()

    @property
    def cur_stage(self):
        """The config dict of the current stage (``self.stages[stage]``)."""
        return self.stages[self._stage]

    @property
    def epoch_in_stage(self):
        """
        The 0-based index of the current epoch within its stage.

        Exactly at a stage boundary this is ``0`` of the following stage.
        Once all stages are exhausted it returns the last stage's number of
        epochs.
        """
        cumsum = 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                return self._epoch - cumsum
            cumsum += stage['epochs']
        return self.stages[-1]['epochs']

    @property
    def iter_in_stage(self):
        """The number of iterations run since the current stage began."""
        cumsum = 0
        for i in range(self._stage):
            cumsum += len(
                self.data_loaders['train']) * self.stages[i]['epochs']
        return self._iter - cumsum

    @property
    def iter_in_epoch(self):
        """The number of iterations run since the current epoch began."""
        return self._iter - len(self.data_loaders['train']) * self._epoch

    def _call_hook(self, name):
        """Call method ``name`` of every registered hook, in order."""
        for hook in self.hooks.values():
            getattr(hook, name)(self)

    def _active_stage(self):
        """
        Return the current stage config. Unlike :attr:`cur_stage`, once all
        stages have finished (``self._stage == len(self.stages)``) this
        returns the last stage instead of raising :obj:`IndexError`.
        """
        return self.stages[min(self._stage, len(self.stages) - 1)]

    def _stage_start_epoch(self, stage):
        """Return the global epoch index at which stage ``stage`` begins."""
        return sum(s['epochs'] for s in self.stages[:stage])

    def _model_device(self):
        """Device of the model's first parameter, or CPU if it has none."""
        param = next(self.model.parameters(), None)
        return torch.device('cpu') if param is None else param.device

    def _build_optimizer(self, stage_cfg):
        """Build the stage's optimizer over all trainable model parameters."""
        return build_optimizer(
            stage_cfg['optimizer'],
            params=[p for p in self.model.parameters() if p.requires_grad])

    def _update_buffer(self, output):
        """Store every model output in the buffer, moving tensors to CPU."""
        for key, value in output.items():
            self.buffer.update(
                key, value.detach().cpu() if torch.is_tensor(value) else value)

    def _switch_data_loader(self):
        """
        Select the data loader for ``self._mode`` and, if its dataset defines
        a callable ``set_state``, notify it of the mode.
        """
        self.data_loader = self.data_loaders[self._mode]
        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

    def get_amp_type(self):
        """
        Return ``'fp16'`` or ``'bf16'`` when AMP is enabled (``fp16`` unless
        the configured dtype is something other than :obj:`torch.float16`),
        or ``None`` when AMP is disabled.
        """
        if not self.amp_cfg['enabled']:
            return None
        dtype = self.amp_cfg.get('dtype', torch.float16)
        return 'fp16' if dtype is torch.float16 else 'bf16'

    def reset_states(self):
        """
        Clear the buffer and reset all stage/epoch/iteration counters, and
        recompute the ``max_*`` totals from the stages and training loader.
        """
        self.buffer.clear()
        self._max_stages = 0 if self.stages is None else len(self.stages)
        self._max_epochs = 0 if self.stages is None else sum(
            stage['epochs'] for stage in self.stages)
        self._max_iters = (len(self.data_loaders['train']) if 'train'
                           in self.data_loaders else 0) * self._max_epochs
        self._start_iter = self._stage = self._epoch = self._iter = 0

    def register_hook(self, hook, before=None, overwrite=True, **kwargs):
        """
        Register a hook or a list of hooks into the engine.

        Args:
            hook (list | :obj:`Hook` | dict | str): The hook or list of hooks
                to be registered. Each hook can be represented as a
                :obj:`Hook`, a dict or a str.
            before (str, optional): Name of the hook to be inserted before. If
                not specified, the new hook will be added to the end of hook
                list. Default: ``None``.
            overwrite (bool, optional): Whether to overwrite the old hook with
                the same name if exists. Default: ``True``.

        Raises:
            TypeError: If ``hook`` is not a :obj:`Hook`, dict, str or list.
            KeyError: If a hook with the same name exists and ``overwrite``
                is ``False``.
            ValueError: If ``before`` names a hook that is not registered.
                The hook list is left unchanged in this case.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.register_hook(
                    h, before=before, overwrite=overwrite, **kwargs)
            return
        elif isinstance(hook, (dict, str)):
            hook = build_hook(hook, **kwargs)
        elif not isinstance(hook, Hook):
            raise TypeError(
                "hook must be a Hook, a dict or a str, but got '{}'".format(
                    type(hook)))

        if not hasattr(self, 'hooks'):
            self.hooks = OrderedDict()

        exists = hook.name in self.hooks
        if exists and not overwrite:
            raise KeyError("hook '{}' exists".format(hook.name))

        # Validate the anchor before mutating anything, so a bad ``before``
        # neither inserts the new hook nor drops the old one.
        if (before is not None and before not in self.hooks
                and before != hook.name):
            raise ValueError("hook '{}' not found".format(before))

        if exists:
            keys = list(self.hooks.keys())
            # Keep an overwritten hook at its original position by default.
            if before is None and keys[-1] != hook.name:
                before = keys[keys.index(hook.name) + 1]
            self.hooks.pop(hook.name)

        self.hooks[hook.name] = hook

        if before is not None:
            # The new hook is last; rotate ``before`` and everything after it
            # (except the new hook) to the end so the new hook precedes it.
            keys = list(self.hooks.keys())
            for key in keys[keys.index(before):-1]:
                self.hooks.move_to_end(key)

    def unregister_hook(self, hook):
        """
        Unregister a hook or a list of hooks from the engine.

        Args:
            hook (list | :obj:`Hook` | str): The hook or list of hooks to be
                unregistered. Each hook can be represented as a :obj:`Hook` or
                a str.

        Raises:
            KeyError: If a hook is not registered.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.unregister_hook(h)
            return

        if isinstance(hook, Hook):
            hook = hook.name
        self.hooks.pop(hook)

    def load_checkpoint(self, checkpoint, **kwargs):
        """
        Load model weights from a checkpoint.

        Args:
            checkpoint (dict | str): A dict, a filename, a URL or a
                ``torchvision://<model_name>`` str indicating the checkpoint.
            **kwargs: Extra arguments forwarded to the checkpoint loader.
        """
        load_checkpoint(
            self.model,
            checkpoint,
            map_location=self._model_device(),
            logger=self.logger,
            **kwargs)

        if isinstance(checkpoint, str):
            self.logger.info('Loaded checkpoint from {}'.format(checkpoint))
        else:
            self.logger.info('Loaded checkpoint')

    def resume(self, checkpoint, **kwargs):
        """
        Resume model weights, optimizer state, and epoch / iteration counters
        from a checkpoint.

        Args:
            checkpoint (dict | str): A dict, a filename or a URL indicating
                the checkpoint.
            **kwargs: Extra arguments forwarded to the checkpoint loader.

        Raises:
            KeyError: If the checkpoint contains no optimizer state. The
                engine is left untouched in this case.
        """
        if isinstance(checkpoint, str):
            checkpoint = get_checkpoint(
                checkpoint, map_location=self._model_device())

        # Fail fast before modifying the model or any counters.
        if 'optimizer' not in checkpoint:
            raise KeyError('optimizer not found in the checkpoint')

        if self.stages != checkpoint['meta']['stages']:
            self.logger.warning(
                'Stages in the engine and checkpoint are mismatched:'
                '\n\nCurrent stages: {}\n\nCheckpoint stages: {}\n'.format([
                    c.to_dict() if isinstance(c, CfgNode) else c
                    for c in self.stages
                ], checkpoint['meta']['stages']))

        load_checkpoint(self.model, checkpoint, logger=self.logger, **kwargs)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = self._start_iter = checkpoint['meta']['iter']

        # Locate the stage containing epoch index ``self._epoch``, using the
        # same rule as :attr:`epoch_in_stage`. The result equals
        # ``len(self.stages)`` when every stage has been completed.
        cumsum, count = 0, 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                break
            cumsum += stage['epochs']
            count += 1
        self._stage = count

        # Use the last stage's optimizer config if training already finished.
        self.optimizer = self._build_optimizer(self._active_stage())
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('Resumed stage {}, epoch {}, iter {}'.format(
            self._stage + 1, self._epoch, self._iter))

    def train_iter(self, data):
        """
        Run one training iteration on ``data`` under autocast.

        Entries of the model output whose key contains ``'loss'`` are stored
        in ``self.losses``. If no ``'loss'`` entry exists, their sum is added
        under ``'loss'``. All outputs are recorded in the buffer.
        """
        self._call_hook('before_train_iter')

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.autocast(device, **self.amp_cfg):
            output = self.model(data, mode=self._mode, **self._kwargs)

            self.losses = {k: v for k, v in output.items() if 'loss' in k}
            if 'loss' not in output:
                self.losses['loss'] = output['loss'] = sum(
                    self.losses.values())

        self._update_buffer(output)

        self._call_hook('after_train_iter')
        self._iter += 1

    def val_iter(self, data):
        """
        Run one validation iteration on ``data`` without gradients and record
        the outputs, adding a summed ``'loss'`` entry if loss terms exist but
        no ``'loss'`` key does.
        """
        self._call_hook('before_val_iter')

        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        if any('loss' in key for key in output) and 'loss' not in output:
            output['loss'] = sum(v for k, v in output.items() if 'loss' in k)

        self._update_buffer(output)

        self._call_hook('after_val_iter')

    def test_iter(self, data):
        """Run one test iteration on ``data`` and record the outputs."""
        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        self._update_buffer(output)

    def train_epoch(self):
        """Run one training epoch over the 'train' loader."""
        self._mode = 'train'
        self.model.train()
        self._switch_data_loader()

        self._call_hook('before_train_epoch')

        for data in self.data_loader:
            self.train_iter(data)

        self._call_hook('after_train_epoch')
        self._epoch += 1

    def val_epoch(self):
        """Run one validation epoch over the 'val' loader."""
        self._mode = 'val'
        self.model.eval()
        # Drop outputs collected by a previous evaluation pass.
        self.buffer.pop('_out', None)
        self._switch_data_loader()

        self._call_hook('before_val_epoch')

        for data in nncore.ProgressBar(self.data_loader):
            self.val_iter(data)

        self._call_hook('after_val_epoch')

    def test_epoch(self):
        """Run one pass over the 'test' loader (no hooks are called)."""
        self._mode = 'test'
        self.model.eval()
        # Drop outputs collected by a previous evaluation pass.
        self.buffer.pop('_out', None)
        self._switch_data_loader()

        for data in nncore.ProgressBar(self.data_loader):
            self.test_iter(data)

    def run_stage(self):
        """
        Run the remaining epochs of the current stage, validating according
        to the stage's ``validation`` config, then advance to the next stage.
        """
        optimizer = self.cur_stage['optimizer']
        if isinstance(optimizer, dict):
            optim_cfg = optimizer.copy()
            optim_type = optim_cfg.pop('type')
            optim_args = ['{}: {}'.format(k, v) for k, v in optim_cfg.items()]
            optim_str = '{}({})'.format(optim_type, ', '.join(optim_args))
        else:
            optim_str = '{}()'.format(optimizer.__class__.__name__)

        self.logger.info('Stage: {}, epochs: {}, optimizer: {}'.format(
            self._stage + 1, self.cur_stage['epochs'], optim_str))

        # A fresh optimizer is built only at the start of a stage. When
        # resuming mid-stage, the optimizer restored by :meth:`resume` is kept.
        if self.epoch_in_stage == 0:
            self.optimizer = self._build_optimizer(self.cur_stage)

        self._call_hook('before_stage')

        start_epoch = self._stage_start_epoch(self._stage)
        for _ in range(self.cur_stage['epochs'] - self.epoch_in_stage):
            self.train_epoch()

            # Number of epochs completed in this stage (1-based). Do not use
            # ``epoch_in_stage`` here: after the last epoch of a non-final
            # stage it has already rolled over to 0 of the next stage, which
            # would skip validation on that epoch.
            done = self._epoch - start_epoch
            cfg = self.cur_stage.get('validation')
            if (cfg is not None and 'val' in self.data_loaders
                    and cfg.get('interval', 0) > 0
                    and done > cfg.get('offset', 0)
                    and done % cfg.get('interval', 0) == 0):
                self.logger.info('Validating...')
                self.val_epoch()

        self._call_hook('after_stage')
        self._stage += 1

    def evaluate(self):
        """
        Perform evaluation. This method is expected to be called after
        validation or testing.

        The collected ``_out`` outputs are gathered across processes. On the
        main process they are interleaved, truncated to the dataset length,
        and passed to the dataset's ``evaluate`` method together with the
        stage's ``validation`` config (minus ``interval`` / ``offset``).
        Other processes return an empty dict.
        """
        blob = self.buffer.pop('_out')
        blob = gather(blob)

        if is_main_process():
            blob = nncore.interleave(blob)[:len(self.data_loader.dataset)]

            # Fall back to the last stage when called after training is done.
            cfg = self._active_stage().get('validation')
            if cfg is not None:
                cfg = cfg.copy()
                cfg.pop('interval', None)
                cfg.pop('offset', None)
            else:
                cfg = dict()

            output = self.data_loader.dataset.evaluate(
                blob, logger=self.logger, **cfg)
        else:
            output = dict()

        sync()
        return output

    def launch(self, eval=False, **kwargs):
        """
        Launch the engine.

        Args:
            eval (bool, optional): Whether to run evaluation only. Default:
                ``False``.
            **kwargs: Extra arguments passed to every model forward call.

        Returns:
            dict | None: The evaluation results when ``eval`` is ``True``.
        """
        self._kwargs = kwargs

        if eval:
            self.logger.info('Evaluating...')
            self.test_epoch()
            output = self.evaluate()
            self.logger.info(
                'Evaluation results: ' +
                ', '.join(['{}: {}'.format(k, v) for k, v in output.items()]))
            return output

        self.logger.info('Distributed: {}, AMP: {}, Debug: {}'.format(
            f'{get_world_size()} processes' if is_distributed() else False,
            self.get_amp_type(), self.debug))
        self.logger.info('Launch engine, host: {}, work_dir: {}'.format(
            nncore.get_host_info(), self.work_dir))

        self._call_hook('before_launch')

        while self._stage < self._max_stages:
            self.run_stage()

        self._call_hook('after_launch')
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc