
yeliudev / nncore · build 6291836846 (push via github)
24 Sep 2023 06:57PM UTC · coverage: 16.46% (-0.2%) from 16.646%
Commit: Add support for amp (yeliudev)

60 of 60 new or added lines in 3 files covered. (100.0%)
677 of 4113 relevant lines covered (16.46%)
3.18 hits per line

Source File: /nncore/engine/engine.py (0.0% of lines covered)
# Copyright (c) Ye Liu. Licensed under the MIT License.

from collections import OrderedDict

import torch

import nncore
from nncore.nn import build_model
from nncore.optim import build_optimizer
from nncore.utils import CfgNode
from .buffer import Buffer
from .builder import build_dataloader, build_hook
from .comm import gather, is_distributed, is_main_process, sync
from .hooks import Hook
from .utils import get_checkpoint, load_checkpoint

_DEFAULT_STAGES = [
    dict(
        epochs=5,
        optimizer=dict(type='SGD', lr=1e-2, momentum=0.9, weight_decay=1e-4),
        lr_schedule=dict(type='iter', policy='cosine'),
        warmup=dict(type='iter', policy='linear', steps=500, ratio=0.001),
        validation=dict(interval=1))
]

_DEFAULT_HOOKS = [
    'TimerHook', 'LrUpdaterHook', 'OptimizerHook', 'CheckpointHook',
    'EvalHook', 'EventWriterHook'
]
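
# When no ``stages`` or extra ``hooks`` are passed to ``Engine``, the defaults
# above apply: a single 5-epoch SGD stage with an iter-based cosine schedule,
# a 500-iteration linear warm-up, validation after every epoch, and the
# timer/lr/optimizer/checkpoint/eval/event-writer hooks registered.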


@nncore.bind_getter('mode', 'max_stages', 'max_epochs', 'max_iters',
                    'start_iter', 'stage', 'epoch', 'iter', 'kwargs')
class Engine(object):
    """
    An engine that can take over the whole training, validation, and testing
    process, with all the babysitting work (stage control, optimizer
    configuration, lr scheduling, checkpoint management, metrics & tensorboard
    writing, etc.) done automatically.

    Args:
        model (:obj:`nn.Module` | cfg | str): The model or config of the model.
            The :obj:`forward` method of the model should return a dict
            containing a ``_avg_factor`` field indicating the number of
            samples in the current batch, and optionally a ``_out`` field
            denoting the model outputs to be collected and evaluated.
        data_loaders (dict | str): The configs of data loaders for training,
            validation, and testing. The dict should be in the format of
            ``dict(train=train_loader, val=val_loader, test=test_loader)``.
        stages (list[dict] | dict | None, optional): The stage config or list
            of stage configs to be scheduled. Each stage config should be a
            dict containing the following fields:

            - `epochs` (int): Number of epochs in the stage.
            - `optimizer` (:obj:`optim.Optimizer` | dict): The optimizer or \
                an optimizer config containing the following fields:

                - `type` (str): Type of the optimizer, which can be accessed \
                    via :obj:`torch.optim` attributes, e.g. ``'SGD'``.
                - `configs for the optimizer, e.g.` ``lr=0.01, momentum=0.9``.

            - `lr_schedule` (dict, optional): The learning rate schedule \
                config containing the following fields:

                - `type` (str): Type of the learning rate schedule. Expected \
                    values include ``'epoch'`` and ``'iter'``, indicating \
                    updating learning rates every epoch or iteration.
                - `policy` (str): The learning rate policy to use. Currently \
                    supported policies include ``step``, ``cosine``, ``exp``, \
                    ``poly``, and ``inv``.
                - `configs for the learning rate policy, e.g.` \
                    ``target_lr=0``. Please refer to :obj:`LrUpdaterHook` for \
                    full configs.

            - `warmup` (dict, optional): The warm-up policy config containing \
                the following fields:

                - `type` (str): Type of the warm-up schedule. Expected values \
                    include ``'epoch'`` and ``'iter'``, indicating warming up \
                    for ``steps`` epochs or iterations.
                - `policy` (str): The warm-up policy to use. Currently \
                    supported policies include ``linear``, ``exp`` and \
                    ``constant``.
                - `steps` (int): Number of epochs or iterations to warm up.
                - `ratio` (float): The ratio of learning rate to start with. \
                    Expected values are in the range of ``0 ~ 1``.

            - `validation` (dict, optional): The validation config containing \
                the following fields:

                - `interval` (int, optional): The interval of performing \
                    validation. ``0`` means not performing validation. \
                    Default: ``0``.
                - `offset` (int, optional): The number of epochs to skip \
                    before counting the interval. Default: ``0``.

            Default: ``None``.
        hooks (list[:obj:`Hook` | dict | str] | None, optional): The list of
            extra hooks to be registered. Each hook can be represented as a
            :obj:`Hook`, a dict or a str. Default: ``None``.
        buffer_size (int, optional): Maximum size of the buffer. Default:
            ``100000``.
        logger (:obj:`logging.Logger` | str | None, optional): The logger or
            name of the logger to use. Default: ``None``.
        work_dir (str | None, optional): Path to the working directory. If not
            specified, the default working directory will be used. Default:
            ``None``.
        seed (int | None, optional): The random seed to use in data loaders.
            Default: ``None``.
        meta (any | None, optional): A dictionary-like object containing meta
            data of this engine. Default: ``None``.
        amp (dict | str | bool | None, optional): Whether and how to use
            automatic mixed precision training. This can be an autocast config
            dict, a dtype name (``'fp16'``/``'float16'`` or
            ``'bf16'``/``'bfloat16'``), or a bool. Default: ``None``.
        debug (bool, optional): Whether to activate debug mode. Default:
            ``False``.

    Example:
        >>> # Build model
        >>> model = build_model()
        ...
        >>> # Build data loaders
        >>> train_loader = build_dataloader(split='train')
        >>> val_loader = build_dataloader(split='val')
        >>> data_loaders = dict(train=train_loader, val=val_loader)
        ...
        >>> # Configure stages:
        >>> # [Stage 1] Train the model for 5 epochs using Adam optimizer with
        >>> # a fixed learning rate (1e-3) and a linear warm-up policy.
        >>> # [Stage 2] Train the model for another 3 epochs using SGD with
        >>> # momentum optimizer and an iter-based cosine learning rate
        >>> # schedule. Perform validation after every training epoch.
        >>> stages = [
        ...     dict(
        ...         epochs=5,
        ...         optimizer=dict(type='Adam', lr=1e-3),
        ...         warmup=dict(type='iter', policy='linear', steps=500)),
        ...     dict(
        ...         epochs=3,
        ...         optimizer=dict(type='SGD', lr=1e-3, momentum=0.9),
        ...         lr_schedule=dict(type='iter', policy='cosine'),
        ...         validation=dict(interval=1))
        ... ]
        ...
        >>> # Initialize and launch engine
        >>> engine = Engine(model, data_loaders, stages=stages)
        >>> engine.launch()
    """

    def __init__(self,
                 model,
                 data_loaders,
                 stages=None,
                 hooks=None,
                 buffer_size=100000,
                 logger=None,
                 work_dir=None,
                 seed=None,
                 meta=None,
                 amp=None,
                 debug=False,
                 **kwargs):
        self.model = build_model(model, **kwargs)

        if 'train' not in data_loaders:
            data_loaders = dict(train=data_loaders)

        for a, b in (('val', 'test'), ('test', 'val')):
            if a not in data_loaders:
                loader = data_loaders[b if b in data_loaders else 'train']
                if isinstance(loader, dict):
                    data_loaders[a] = loader.copy()
                else:
                    data_loaders[a] = loader

        self.data_loaders = {
            k: build_dataloader(v, seed=seed)
            for k, v in data_loaders.items()
        }

        if isinstance(stages, dict):
            self.stages = [stages]
        else:
            self.stages = stages or _DEFAULT_STAGES

        self.register_hook(_DEFAULT_HOOKS)
        if is_distributed():
            self.register_hook('SamplerSeedHook', before='OptimizerHook')
        if hooks is not None:
            self.register_hook(hooks)

        time_str = nncore.get_timestamp()
        self.work_dir = work_dir or nncore.join('work_dirs', time_str)

        log_file = nncore.join(self.work_dir, time_str + '.log')
        self.logger = nncore.get_logger(logger, log_file=log_file)

        self.buffer = Buffer(max_size=buffer_size, logger=self.logger)
        self.reset_states()

        if isinstance(amp, dict):
            amp.setdefault('enabled', True)
            self.amp_cfg = amp
        elif isinstance(amp, str):
            if amp in ('fp16', 'float16'):
                dtype = torch.float16
            elif amp in ('bf16', 'bfloat16'):
                dtype = torch.bfloat16
            else:
                raise TypeError(
                    "Amp training only supports 'float16' or 'bfloat16' data "
                    "types, but got '{}'".format(amp))
            self.amp_cfg = dict(enabled=True, dtype=dtype)
        else:
            self.amp_cfg = dict(enabled=bool(amp))

        self.meta = meta
        self.debug = debug
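
    # The ``amp`` argument accepts several forms, e.g. ``Engine(model,
    # data_loaders, amp='fp16')``, ``amp='bf16'``, ``amp=True``, or an
    # explicit autocast config such as ``amp=dict(dtype=torch.bfloat16)``;
    # the resulting ``amp_cfg`` is forwarded to ``torch.autocast`` in
    # ``train_iter``.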

    @property
    def cur_stage(self):
        return self.stages[self._stage]

    @property
    def epoch_in_stage(self):
        cumsum = 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                return self._epoch - cumsum
            cumsum += stage['epochs']
        return self.stages[-1]['epochs']

    @property
    def iter_in_stage(self):
        cumsum = 0
        for i in range(self._stage):
            cumsum += len(
                self.data_loaders['train']) * self.stages[i]['epochs']
        return self._iter - cumsum

    @property
    def iter_in_epoch(self):
        return self._iter - len(self.data_loaders['train']) * self._epoch
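
    # Worked example for the counters above (hypothetical numbers): with two
    # stages of 5 and 3 epochs and 100 training iterations per epoch, a global
    # state of ``_epoch = 6`` and ``_iter = 650`` gives ``epoch_in_stage = 1``,
    # ``iter_in_stage = 150`` and ``iter_in_epoch = 50``.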

    def _call_hook(self, name):
        for hook in self.hooks.values():
            getattr(hook, name)(self)

    def get_amp_type(self):
        if self.amp_cfg['enabled']:
            dtype = self.amp_cfg.get('dtype', torch.float16)
            return 'fp16' if dtype is torch.float16 else 'bf16'

    def reset_states(self):
        self.buffer.clear()
        self._max_stages = 0 if self.stages is None else len(self.stages)
        self._max_epochs = 0 if self.stages is None else sum(
            stage['epochs'] for stage in self.stages)
        self._max_iters = (len(self.data_loaders['train']) if 'train'
                           in self.data_loaders else 0) * self._max_epochs
        self._start_iter = self._stage = self._epoch = self._iter = 0

    def register_hook(self, hook, before=None, overwrite=True, **kwargs):
        """
        Register a hook or a list of hooks into the engine.

        Args:
            hook (list | :obj:`Hook` | dict | str): The hook or list of hooks
                to be registered. Each hook can be represented as a
                :obj:`Hook`, a dict or a str.
            before (str, optional): Name of the hook to be inserted before. If
                not specified, the new hook will be added to the end of the
                hook list. Default: ``None``.
            overwrite (bool, optional): Whether to overwrite the old hook with
                the same name if it exists. Default: ``True``.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.register_hook(
                    h, before=before, overwrite=overwrite, **kwargs)
            return
        elif isinstance(hook, (dict, str)):
            hook = build_hook(hook, **kwargs)
        elif not isinstance(hook, Hook):
            raise TypeError(
                "hook must be a Hook, a dict or a str, but got '{}'".format(
                    type(hook)))

        if not hasattr(self, 'hooks'):
            self.hooks = OrderedDict()

        if hook.name in self.hooks:
            if overwrite:
                keys = list(self.hooks.keys())
                if before is None and keys[-1] != hook.name:
                    before = keys[keys.index(hook.name) + 1]
                self.hooks.pop(hook.name)
            else:
                raise KeyError("hook '{}' exists".format(hook.name))

        self.hooks[hook.name] = hook

        if before is not None:
            if before not in self.hooks:
                raise ValueError("hook '{}' not found".format(before))

            keys = list(self.hooks.keys())
            for key in keys[keys.index(before):-1]:
                self.hooks.move_to_end(key)
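
    # Usage sketch: the first call mirrors the one made in ``__init__``; the
    # dict form is assumed to follow the same ``type``-keyed convention as the
    # optimizer configs above.
    #
    #     engine.register_hook('SamplerSeedHook', before='OptimizerHook')
    #     engine.register_hook(dict(type='CheckpointHook'))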

    def unregister_hook(self, hook):
        """
        Unregister a hook or a list of hooks from the engine.

        Args:
            hook (list | :obj:`Hook` | str): The hook or list of hooks to be
                unregistered. Each hook can be represented as a :obj:`Hook` or
                a str.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.unregister_hook(h)
            return

        if isinstance(hook, Hook):
            hook = hook.name
        self.hooks.pop(hook)

    def load_checkpoint(self, checkpoint, **kwargs):
        """
        Load checkpoint from a file or a URL.

        Args:
            checkpoint (dict | str): A dict, a filename, a URL or a
                ``torchvision://<model_name>`` str indicating the checkpoint.
        """
        load_checkpoint(
            self.model,
            checkpoint,
            map_location=next(self.model.parameters()).device,
            logger=self.logger,
            **kwargs)

        if isinstance(checkpoint, str):
            self.logger.info('Loaded checkpoint from {}'.format(checkpoint))
        else:
            self.logger.info('Loaded checkpoint')
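
    # Usage sketch (the path and model name are illustrative only):
    #
    #     engine.load_checkpoint('work_dirs/exp/epoch_5.pth')
    #     engine.load_checkpoint('torchvision://resnet50')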

    def resume(self, checkpoint, **kwargs):
        """
        Resume from a checkpoint file.

        Args:
            checkpoint (dict | str): A dict, a filename or a URL indicating
                the checkpoint.
        """
        if isinstance(checkpoint, str):
            checkpoint = get_checkpoint(
                checkpoint, map_location=next(self.model.parameters()).device)

        if self.stages != checkpoint['meta']['stages']:
            self.logger.warn(
                'Stages in the engine and the checkpoint do not match:'
                '\n\nCurrent stages: {}\n\nCheckpoint stages: {}'.format([
                    c.to_dict() if isinstance(c, CfgNode) else c
                    for c in self.stages
                ], checkpoint['meta']['stages']))

        load_checkpoint(self.model, checkpoint, logger=self.logger, **kwargs)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = self._start_iter = checkpoint['meta']['iter']

        cumsum, count = 0, 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                break
            cumsum += stage['epochs']
            count += 1
        self._stage = count

        if 'optimizer' in checkpoint:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'], params=self.model.parameters())
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            raise KeyError('optimizer not found in the checkpoint')

        self.logger.info('Resumed stage {}, epoch {}, iter {}'.format(
            self._stage + 1, self._epoch, self._iter))
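
    # ``resume`` restores the epoch / iteration counters and the optimizer
    # state from the checkpoint and re-derives the current stage index from
    # the configured stage lengths, so ``launch`` continues training from
    # where the checkpoint left off.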

    def train_iter(self, data):
        self._call_hook('before_train_iter')

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.autocast(device, **self.amp_cfg):
            output = self.model(data, mode=self._mode, **self._kwargs)

            self.losses = {k: v for k, v in output.items() if 'loss' in k}
            if 'loss' not in output:
                self.losses['loss'] = output['loss'] = sum(
                    v for v in self.losses.values())

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_train_iter')
        self._iter += 1
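
    # The forward pass above runs under ``torch.autocast`` with the configured
    # amp settings. Every output key containing ``'loss'`` is collected into
    # ``self.losses``; if the model did not return a total ``'loss'``, the
    # partial losses are summed into one.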

    def val_iter(self, data):
        self._call_hook('before_val_iter')

        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        if any('loss' in key for key in output) and 'loss' not in output:
            output['loss'] = sum(v for k, v in output.items() if 'loss' in k)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_val_iter')

    def test_iter(self, data):
        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

    def train_epoch(self):
        self._mode = 'train'
        self.model.train()
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_train_epoch')

        for data in self.data_loader:
            self.train_iter(data)

        self._call_hook('after_train_epoch')
        self._epoch += 1

    def val_epoch(self):
        self.logger.info('Validating...')
        self._mode = 'val'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_val_epoch')

        for data in nncore.ProgressBar(self.data_loader):
            self.val_iter(data)

        self._call_hook('after_val_epoch')

    def test_epoch(self):
        self.logger.info('Evaluating...')
        self._mode = 'test'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        for data in nncore.ProgressBar(self.data_loader):
            self.test_iter(data)
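
    # Each epoch method switches the engine mode, picks the matching data
    # loader, and, if the underlying dataset defines a ``set_state`` method,
    # notifies it of the new mode. ``val_epoch`` and ``test_epoch`` also drop
    # any stale ``_out`` entries from the buffer before collecting fresh
    # outputs for evaluation.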

    def run_stage(self):
        if isinstance(self.cur_stage['optimizer'], dict):
            optim_cfg = self.cur_stage['optimizer'].copy()
            optim_type = optim_cfg.pop('type')
            optim_args = ['{}: {}'.format(k, v) for k, v in optim_cfg.items()]
            optim_str = '{}({})'.format(optim_type, ', '.join(optim_args))
        else:
            optim_str = '{}()'.format(
                self.cur_stage['optimizer'].__class__.__name__)

        self.logger.info('Stage: {}, epochs: {}, optimizer: {}'.format(
            self._stage + 1, self.cur_stage['epochs'], optim_str))

        if self.epoch_in_stage == 0:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'],
                params=[p for p in self.model.parameters() if p.requires_grad])

        self._call_hook('before_stage')

        for _ in range(self.cur_stage['epochs'] - self.epoch_in_stage):
            self.train_epoch()
            cfg = self.cur_stage.get('validation')
            if (cfg is not None and 'val' in self.data_loaders
                    and cfg.get('interval', 0) > 0
                    and self.epoch_in_stage > cfg.get('offset', 0)
                    and self.epoch_in_stage % cfg.get('interval', 0) == 0):
                self.val_epoch()

        self._call_hook('after_stage')
        self._stage += 1
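
    # Validation inside a stage is triggered only when a 'val' data loader
    # exists, ``interval`` is positive, the number of completed epochs in the
    # stage exceeds ``offset``, and that number is divisible by ``interval``.
    # For example, ``validation=dict(interval=2)`` validates after the second,
    # fourth, ... completed epoch of the stage.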

    def evaluate(self):
        """
        Perform evaluation. This method is expected to be called after
        validation or testing.
        """
        blob = self.buffer.pop('_out')
        blob = gather(blob)

        if is_main_process():
            blob = nncore.interleave(blob)[:len(self.data_loader.dataset)]

            cfg = self.cur_stage.get('validation')
            if cfg is not None:
                cfg = cfg.copy()
                cfg.pop('interval', None)
                cfg.pop('offset', None)
            else:
                cfg = dict()

            output = self.data_loader.dataset.evaluate(
                blob, logger=self.logger, **cfg)
        else:
            output = dict()

        sync()
        return output
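
    # ``evaluate`` gathers the buffered ``_out`` results from all processes;
    # only the main process interleaves them, trims them to the dataset
    # length, and calls ``dataset.evaluate`` (passing any remaining validation
    # config entries), while the other processes return an empty dict after
    # syncing.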

    def launch(self, eval=False, **kwargs):
        """
        Launch the engine.

        Args:
            eval (bool, optional): Whether to run evaluation only. Default:
                ``False``.
        """
        self._kwargs = kwargs

        if eval:
            self.test_epoch()
            output = self.evaluate()
            self.logger.info(
                'Evaluation results: ' +
                ', '.join(['{}: {}'.format(k, v) for k, v in output.items()]))
            return output

        self.logger.info('Distributed: {}, AMP: {}, Debug: {}'.format(
            is_distributed(), self.get_amp_type(), self.debug))
        self.logger.info('Launch engine, host: {}, work_dir: {}'.format(
            nncore.get_host_info(), self.work_dir))

        self._call_hook('before_launch')

        while self._stage < self._max_stages:
            self.run_stage()

        self._call_hook('after_launch')
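
    # Usage sketch: ``engine.launch()`` runs the configured stages end to end,
    # while ``engine.launch(eval=True)`` only runs a test epoch followed by
    # evaluation and returns the metric dict. Extra keyword arguments are
    # forwarded to the model's ``forward`` call on every iteration.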