yeliudev / nncore, build 9528342153 (push, via GitHub)
Commit by yeliudev: "Update download script"
15 Jun 2024 01:10PM UTC, coverage: 15.572% (-0.05%) from 15.622%

1 of 24 new or added lines in 2 files covered (4.17%).
2 existing lines in 1 file are now uncovered.
678 of 4354 relevant lines covered (15.57%).
3.01 hits per line.
Source File: /nncore/engine/utils.py (coverage: 0.0%)
# Copyright (c) Ye Liu. Licensed under the MIT License.

import os
import random
from datetime import datetime
from importlib import import_module
from pkgutil import walk_packages

import numpy as np
import torch
import torchvision
from torch.hub import load_state_dict_from_url

import nncore
from nncore.nn import move_to_device
from .comm import broadcast, is_main_process, sync

DATASETS = nncore.Registry('dataset')


def _load_url_dist(url, **kwargs):
    # Download on the main process first, then let every process load the
    # cached file after synchronization.
    if is_main_process():
        load_state_dict_from_url(url, **kwargs)

    sync()
    state_dict = load_state_dict_from_url(url, **kwargs)

    return state_dict


def _match_keys(keys, cand):
    # Check whether the candidate key starts with any of the given dotted
    # key prefixes (e.g. 'backbone' matches 'backbone.layer1.weight').
    keys = [k.split('.') for k in keys]
    cand = cand.split('.')

    for key in keys:
        if cand[:len(key)] == key:
            return True

    return False

def _load_state_dict(module,
                     state_dict,
                     strict=False,
                     warning=True,
                     logger=None):
    unexpected_keys = []
    missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _load(module, prefix=''):
        # Recursively load parameters into the module and its children.
        local_metadata = dict() if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     missing_keys, unexpected_keys, err_msg)
        for name, child in module._modules.items():
            if child is not None:
                _load(child, prefix + name + '.')

    _load(module)
    _load = None  # break the reference cycle created by the closure

    if len(unexpected_keys) > 0:
        err_msg.append('Unexpected keys in source state dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if len(missing_keys) > 0:
        err_msg.append('Missing keys in source state dict: {}\n'.format(
            ', '.join(missing_keys)))

    if is_main_process() and len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(
                'error in loading state dict for {}:\n\t{}'.format(
                    module.__class__.__name__, err_msg))
        if warning:
            nncore.log_or_print(err_msg, logger, log_level='WARNING')

def generate_random_seed(sync=True, src=0, group=None):
    """
    Generate a random seed.

    Args:
        sync (bool, optional): Whether to synchronize the random seed among the
            processes in the group in distributed settings. Default: ``True``.
        src (int, optional): The source rank of the process in distributed
            settings. This argument is valid only when ``sync==True``. Default:
            ``0``.
        group (:obj:`dist.ProcessGroup` | None, optional): The process group
            to use in distributed settings. This argument is valid only when
            ``sync==True``. If not specified, the default process group will
            be used. Default: ``None``.

    Returns:
        int: The generated random seed.
    """
    seed = 0
    while len(str(seed)) != 8:
        seed = os.getpid() + int.from_bytes(os.urandom(4), 'big') + int(
            datetime.now().strftime('%f'))
    if sync:
        seed = broadcast(data=seed, src=src, group=group)
    return seed

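# Editor's note: the block below is an illustrative usage sketch, not part of
# the original file. generate_random_seed() returns an 8-digit integer seed;
# with sync=True (the default) the seed produced on rank `src` is broadcast so
# that every process in the group uses the same value.
def _example_generate_random_seed():
    # Use sync=False so the example also works outside distributed settings.
    seed = generate_random_seed(sync=False)
    assert len(str(seed)) == 8
    return seed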

def set_random_seed(seed=None, benchmark=False, deterministic=False, **kwargs):
    """
    Set random seed for ``random``, ``numpy``, and ``torch`` packages. If
    ``seed`` is not specified, this method will generate and return a new
    random seed.

    Args:
        seed (int | None, optional): The random seed to use. If not specified,
            a new random seed will be generated. Default: ``None``.
        benchmark (bool, optional): Whether to enable benchmark mode. Default:
            ``False``.
        deterministic (bool, optional): Whether to enable deterministic mode.
            Default: ``False``.

    Returns:
        int: The actually used random seed.
    """
    if seed is None:
        seed = generate_random_seed(**kwargs)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.benchmark = benchmark
    torch.backends.cudnn.deterministic = deterministic

    return seed

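# Editor's note: illustrative usage sketch, not part of the original file.
# set_random_seed() seeds `random`, `numpy`, and `torch` together and returns
# the seed it used, so the value can be logged and reused for reproducibility.
def _example_set_random_seed():
    # Passing an explicit seed skips generate_random_seed(); deterministic=True
    # also switches cuDNN into deterministic mode.
    seed = set_random_seed(seed=42, deterministic=True)
    assert seed == 42
    return seed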

def get_checkpoint(file_or_url, map_location=None, **kwargs):
    """
    Get a checkpoint from a file or a URL.

    Args:
        file_or_url (str): The filename or URL of the checkpoint.
        map_location (str | None, optional): Same as the :obj:`torch.load`
            interface. Default: ``None``.

    Returns:
        :obj:`OrderedDict` | dict: The loaded checkpoint.
    """
    if file_or_url.startswith('torchvision://'):
        model_urls = dict()
        for _, name, ispkg in walk_packages(torchvision.models.__path__):
            if ispkg:
                continue
            mod = import_module('torchvision.models.{}'.format(name))
            if hasattr(mod, 'model_urls'):
                urls = getattr(mod, 'model_urls')
                model_urls.update(urls)
        checkpoint = _load_url_dist(model_urls[file_or_url[14:]], **kwargs)
    elif file_or_url.startswith(('http://', 'https://')):
        checkpoint = _load_url_dist(file_or_url, **kwargs)
    else:
        checkpoint = torch.load(file_or_url, map_location=map_location)

    return checkpoint

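# Editor's note: illustrative usage sketch, not part of the original file.
# get_checkpoint() accepts a local path, an http(s) URL, or a
# 'torchvision://<model_name>' string. The path below is a hypothetical
# placeholder, not a file shipped with nncore.
def _example_get_checkpoint():
    checkpoint = get_checkpoint('checkpoints/model.pth', map_location='cpu')
    return checkpoint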

def load_checkpoint(model,
                    checkpoint,
                    map_location=None,
                    strict=False,
                    warning=True,
                    keys=None,
                    logger=None,
                    **kwargs):
    """
    Load a checkpoint from a file or a URL.

    Args:
        model (:obj:`nn.Module`): The module to load the checkpoint into.
        checkpoint (dict | str): A dict, a filename, a URL or a
            ``torchvision://<model_name>`` str indicating the checkpoint.
        map_location (str | None, optional): Same as the :obj:`torch.load`
            interface. Default: ``None``.
        strict (bool, optional): Whether to require the params of the model
            and checkpoint to match. If ``True``, raise an error when the
            params do not match exactly. Default: ``False``.
        warning (bool, optional): Whether to display warnings if the params
            of the model and checkpoint are not matched. Default: ``True``.
        keys (list[str] | None, optional): The list of parameter keys to load.
            Default: ``None``.
        logger (:obj:`logging.Logger` | str | None, optional): The logger or
            name of the logger for displaying error messages. Default:
            ``None``.

    Returns:
        :obj:`nn.Module`: The model with the loaded checkpoint.
    """
    if isinstance(checkpoint, str):
        checkpoint = get_checkpoint(
            checkpoint, map_location=map_location, **kwargs)

    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get('state_dict', checkpoint)
    else:
        raise RuntimeError('no state dict found in the checkpoint file')

    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    if keys is not None:
        state_dict = {
            k: v
            for k, v in state_dict.items() if _match_keys(keys, k)
        }

    _load_state_dict(
        getattr(model, 'module', model),
        state_dict,
        strict=strict,
        warning=warning,
        logger=logger)

    return model

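# Editor's note: illustrative usage sketch, not part of the original file.
# load_checkpoint() resolves the checkpoint (dict, path, URL, or
# torchvision:// name), strips a leading 'module.' prefix left by
# (Distributed)DataParallel, optionally keeps only keys matching the given
# prefixes, and loads non-strictly by default. The path and key prefix below
# are hypothetical placeholders.
def _example_load_checkpoint(model):
    return load_checkpoint(
        model, 'checkpoints/model.pth', map_location='cpu', keys=['backbone'])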

def save_checkpoint(model, filename, optimizer=None, meta=None):
    """
    Save checkpoint to a file.

    The checkpoint object will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``, where ``meta`` contains the version of nncore and the time
    info by default.

    Args:
        model (:obj:`nn.Module`): The model whose params are to be saved.
        filename (str): Path to the checkpoint file.
        optimizer (:obj:`optim.Optimizer` | None, optional): The optimizer to
            be saved. Default: ``None``.
        meta (dict | None, optional): The metadata to be saved. Default:
            ``None``.

    Returns:
        dict: The saved checkpoint.
    """
    if meta is None:
        meta = dict()

    meta.update(
        nncore_version=nncore.__version__, create_time=nncore.get_time_str())

    state_dict = getattr(model, 'module', model).state_dict()
    checkpoint = dict(meta=meta, state_dict=state_dict)

    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()

    checkpoint = move_to_device(checkpoint, 'cpu')

    nncore.mkdir(nncore.dir_name(filename))
    torch.save(checkpoint, filename)

    return checkpoint
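# Editor's note: illustrative usage sketch, not part of the original file.
# save_checkpoint() moves the checkpoint to CPU, creates the target directory
# if needed, and writes a dict with 'meta', 'state_dict' and (optionally)
# 'optimizer'. The output path below is a hypothetical placeholder.
def _example_save_checkpoint(model, optimizer):
    checkpoint = save_checkpoint(
        model, 'work_dirs/epoch_1.pth', optimizer=optimizer)
    assert 'meta' in checkpoint and 'state_dict' in checkpoint
    return checkpoint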