• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JoranAngevaare / optim_esm_tools / 14750707513

30 Apr 2025 08:55AM UTC coverage: 89.832% (+0.02%) from 89.81%
14750707513

Pull #238

github

web-flow
Merge 594f364cc into 90367cc2c
Pull Request #238: Further sort candidate regions on lat lon in `Merger`

9 of 9 new or added lines in 2 files covered. (100.0%)

224 existing lines in 7 files now uncovered.

2942 of 3275 relevant lines covered (89.83%)

3.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.0
/optim_esm_tools/utils.py
1
import inspect
4✔
2
import os
4✔
3
import socket
4✔
4
import sys
4✔
5
import time
4✔
6
import typing as ty
4✔
7
import warnings
4✔
8
from collections import defaultdict
4✔
9
from functools import wraps
4✔
10
from importlib import import_module
4✔
11
from platform import python_version
4✔
12

13
import numpy as np
4✔
14
import pandas as pd
4✔
15
from immutabledict import immutabledict
4✔
16

17
try:
4✔
18
    from git import Repo, InvalidGitRepositoryError
4✔
19

20
    GIT_INSTALLED = True
4✔
21
except ImportError:  # pragma: no cover
22
    GIT_INSTALLED = False
23

24
import sys
4✔
25

26

27
# From https://github.com/AxFoundation/strax/blob/136a16975b18ee87500051fd81a90c894d9b58dc/strax/utils.py#L33
28
if any('jupyter' in arg for arg in sys.argv):
4✔
29
    # In some cases we are not using any notebooks,
30
    # Taken from 44952863 on stack overflow thanks!
31
    from tqdm.notebook import tqdm  # pragma: no cover
32
else:
33
    from tqdm import tqdm
4✔
34

35

36
# https://github.com/JoranAngevaare/thesis_plots/blob/d828c08e6f6c9c6926527220a23fd0e61e5d8c60/thesis_plots/main.py
37
root_folder = os.path.join(os.path.split(os.path.realpath(__file__))[0], '..')
4✔
38

39
from optim_esm_tools.plotting.plot_utils import *
4✔
40

41

42
def print_versions(
    modules=('optim_esm_tools',),
    print_output=True,
    include_python=True,
    return_string=False,
    include_git=True,
):
    """Print versions of modules installed.

    :param modules: Modules to report, given as a str, tuple or list. E.g.
        print_versions(modules=('numpy', 'optim_esm_tools',))
    :param print_output: if True, print the report (host name followed
        by the version table).
    :param include_python: if True, the first row describes the running
        python interpreter.
    :param return_string: optional. Return the report text instead of
        the DataFrame.
    :param include_git: Include the current branch and latest commit
        hash (switched off with a warning when gitpython is missing).
    :return: the report string when return_string is True, otherwise a
        DataFrame with the columns module, version, path and git.
    """
    if include_git and not GIT_INSTALLED:  # pragma: no cover
        warnings.warn('Git is not installed, maybe try pip install gitpython')
        include_git = False

    column_names = ('module', 'version', 'path', 'git')
    table = defaultdict(list)

    def add_row(*cells):
        # Append one value to each column, in column order.
        for column, cell in zip(column_names, cells):
            table[column].append(cell)

    if include_python:
        add_row('python', python_version(), sys.executable, None)

    for module_name in to_str_tuple(modules):
        found = _version_info_for_module(module_name, include_git=include_git)
        if found is None:
            continue  # pragma: no cover
        add_row(module_name, *found)

    df = pd.DataFrame(table)
    report = f'Host {socket.getfqdn()}\n{df.to_string(index=False)}'
    if print_output:
        print(report)
    if return_string:
        return report
    return df
82

83

84
def _version_info_for_module(module_name, include_git):
4✔
85
    try:
4✔
86
        mod = import_module(module_name)
4✔
87
    except ImportError:
4✔
88
        print(f'{module_name} is not installed')
4✔
89
        return
4✔
90
    git = None
4✔
91
    version = mod.__dict__.get('__version__', None)
4✔
92
    module_path = mod.__dict__.get('__path__', [None])[0]
4✔
93
    if include_git:
4✔
94
        try:
4✔
95
            repo = Repo(module_path, search_parent_directories=True)
4✔
96
        except InvalidGitRepositoryError:
4✔
97
            # not a git repo
98
            pass
4✔
99
        else:
100
            try:
4✔
101
                branch = repo.active_branch
4✔
102
            except TypeError:  # pragma: no cover
103
                branch = 'unknown'
104
            try:
4✔
105
                commit_hash = repo.head.object.hexsha
4✔
106
            except TypeError:  # pragma: no cover
107
                commit_hash = 'unknown'
108
            git = f'branch:{branch} | {commit_hash[:7]}'
4✔
109
    return version, module_path, git
4✔
110

111

112
def to_str_tuple(
    x: ty.Union[str, bytes, list, tuple, pd.Series, np.ndarray],
) -> ty.Tuple[str, ...]:
    """Convert any sensible instance to a tuple of strings.

    Adapted from https://github.com/AxFoundation/strax/blob/d3608efc77acd52e1d5a208c3092b6b45b27a6e2/strax/utils.py#242

    :param x: a single str/bytes, or a list, tuple, pandas Series or numpy
        array of them. Elements are not validated or converted to str.
    :return: a tuple of the elements; a single str/bytes becomes a
        one-element tuple and a tuple is returned as-is.
    :raises TypeError: for any other input type.
    """
    if isinstance(x, (str, bytes)):
        return (x,)
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    if isinstance(x, pd.Series):
        # .tolist() yields native Python objects rather than numpy scalars.
        return tuple(x.tolist())
    if isinstance(x, np.ndarray):
        # ravel() so a 0-d array gives one element instead of being iterated
        # character by character.
        return tuple(x.ravel().tolist())
    raise TypeError(
        f'Expected string or tuple of strings, got {type(x)}',
    )  # pragma: no cover
128

129

130
def mathrm(string):
    """Short alias: wrap *string* in a LaTeX mathrm expression via ``string_to_mathrm``."""
    formatted = string_to_mathrm(string)
    return formatted
132

133

134
def string_to_mathrm(string):
    r"""Wrap a string in mathrm mode for latex labels.

    Every space is escaped as ``\ `` so it survives LaTeX math mode, and the
    result is enclosed in ``$\mathrm{...}$``.
    """
    escaped = '\\ '.join(string.split(' '))
    return '$\\mathrm{' + escaped + '}$'
138

139

140
def filter_keyword_arguments(
    kw: ty.Mapping,
    func: ty.Callable,
    allow_varkw: bool = False,
) -> ty.Mapping:
    """Only pass accepted keyword arguments (from kw) into function "func".

    A key is accepted when it names one of ``func``'s regular or
    keyword-only parameters.

    Args:
        kw (ty.Mapping): kwargs that could go into function func
        func (ty.Callable): a function
        allow_varkw (bool, optional): If True and the function takes **kwargs, return
            <kw> unchanged. Defaults to False.

    Returns:
        dict: Filtered keyword arguments (or ``kw`` itself, see allow_varkw).
    """
    spec = inspect.getfullargspec(func)
    if allow_varkw and spec.varkw is not None:
        return kw  # pragma: no cover
    # getfullargspec reports keyword-only parameters (after `*` or `*args`)
    # in kwonlyargs, separately from args; both can be passed by keyword.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    return {k: v for k, v in kw.items() if k in accepted}
160

161

162
def check_accepts(
    accepts: ty.Mapping[str, ty.Iterable] = immutabledict(unit=('absolute', 'std')),
    do_raise: bool = True,
):
    """Wrapper for function if certain kwargs are from a defined list of
    variables.

    Example:
        ```
            @check_accepts(accepts=dict(far=('boo', 'booboo')))
            def bla(far):
                print(far)

            bla(far='boo')  # prints boo
            bla(far='booboo')  # prints booboo
            bla(far=1)  # raises ValueError
        ```

    Only arguments passed by keyword are checked; positional arguments are
    forwarded unchecked.

    Args:
        accepts (ty.Mapping[str, ty.Iterable], optional): which kwarg to accept a limited set of options.
            Defaults to immutabledict(unit=('absolute', 'std')).
        do_raise (bool, optional): if False, don't raise an error but just warn. Defaults to True.

    Returns:
        wrapped function
    """

    def somedec_outer(fn):
        @wraps(fn)
        def somedec_inner(*args, **kwargs):
            problems = [
                f'{k} for {v} but only accepts {accepts[k]}'
                for k, v in kwargs.items()
                if k in accepts and v not in accepts[k]
            ]
            # Separate multiple violations instead of gluing them together.
            message = '; '.join(problems)
            if message:  # pragma: no cover
                if do_raise:
                    raise ValueError(message)
                # stacklevel=2 attributes the warning to the caller, not this wrapper.
                warnings.warn(message, stacklevel=2)
            return fn(*args, **kwargs)

        return somedec_inner

    return somedec_outer
209

210

211
def add_load_kw(func):
    """Add a ``load`` keyword that calls ``.load()`` on the wrapped function's result.

    If ``func`` has its own ``load`` parameter, the keyword is forwarded to it
    and also honoured here. Otherwise it is removed from the keyword arguments
    before ``func`` is called. When ``load`` is not given, the result is
    returned as-is.
    """
    # The signature of func cannot change between calls, so inspect it once.
    func_takes_load = 'load' in inspect.signature(func).parameters

    @wraps(func)
    def dep_fun(*args, **kwargs):
        if func_takes_load:
            add_load = kwargs.get('load', False)
        else:
            add_load = kwargs.pop('load', False)
        res = func(*args, **kwargs)
        return res.load() if add_load else res

    return dep_fun
224

225

226
def deprecated(func, message='is deprecated'):
    """Wrap ``func`` so that every call emits a DeprecationWarning.

    :param func: the function to wrap.
    :param message: text appended to 'calling <name>' in the warning.
    :return: the wrapped function (metadata copied via functools.wraps).
    """

    @wraps(func)
    def dep_fun(*args, **kwargs):
        warnings.warn(
            f'calling {func.__name__} {message}',
            category=DeprecationWarning,
            # Point the warning at the caller of the deprecated function,
            # not at this wrapper.
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return dep_fun
236

237

238
def _chopped_string(string, max_len):
4✔
239
    string = str(string)
4✔
240
    return string if len(string) < max_len else string[:max_len] + '...'
4✔
241

242

243
@check_accepts(accepts=dict(_report=('debug', 'info', 'warning', 'error', 'print')))
def timed(
    *a,
    seconds: ty.Optional[int] = None,
    _report: ty.Optional[str] = None,
    _args_max: int = 20,
    _fmt: str = '.2g',
    _stacklevel: int = 2,
):
    """Time a function and report if it takes more than <seconds>.

    Can be used bare (``@timed``) or with arguments (``@timed(seconds=10)``).

    Args:
        seconds (int, optional): Only calls that take longer than this (in seconds) are
            reported. Defaults to config['time_tool']['min_seconds'].
        _report (str, optional): Method of reporting: 'print', or the name of the method
            called on get_logger() ('debug', 'info', 'warning', 'error').
            Defaults to config['time_tool']['reporter'].
        _args_max (int, optional): Max number of characters in the message for the args and
            kwargs of the function. Defaults to 20.
        _fmt (str, optional): time format specification. Defaults to '.2g'.
        _stacklevel (int, optional): stacklevel passed to the logger call. Defaults to 2.
    """
    if seconds is None or _report is None:
        from .config import config

        if seconds is None:
            seconds = float(config['time_tool']['min_seconds'])
        if _report is None:
            _report = config['time_tool']['reporter']

    def somedec_outer(fn):
        @wraps(fn)
        def timed_func(*args, **kwargs):
            # perf_counter is monotonic, unlike time.time(), so system clock
            # adjustments cannot distort the measured duration.
            t0 = time.perf_counter()
            res = fn(*args, **kwargs)
            dt = time.perf_counter() - t0
            if dt > seconds:
                hours = '' if dt < 3600 else f' ({dt/3600:{_fmt}} h)'
                message = (
                    f'{fn.__name__} took {dt:{_fmt}} s{hours} (for '
                    f'{_chopped_string(args, _args_max)}, '
                    f'{_chopped_string(kwargs, _args_max)})'
                ).replace('\n', ' ')
                if _report == 'print':
                    print(message)
                else:
                    from .config import get_logger

                    getattr(get_logger(), _report)(message, stacklevel=_stacklevel)
            return res

        return timed_func

    if a and callable(a[0]):
        # Used as a bare decorator (@timed): a[0] is the function to wrap.
        return somedec_outer(a[0])
    return somedec_outer
295

296

297
@check_accepts(accepts=dict(_report=('debug', 'info', 'warning', 'error', 'print')))
def logged_tqdm(*a, log=None, _report='warning', **kw):
    """Iterate like ``tqdm(*a, **kw)`` while also reporting the progress bar.

    The progress bar is reported before every yielded item and once more after
    the iterable is exhausted.

    Args:
        *a, **kw: forwarded to ``tqdm``.
        log: logger to report to; defaults to ``get_logger()`` from ``.config``.
        _report (str): 'print', or the name of the logger method to call.
            Defaults to 'warning'.

    Yields:
        the items of the wrapped iterable.
    """
    if _report == 'print':

        def report(bar):
            print(bar)

    else:
        from .config import get_logger

        log_method = getattr(log or get_logger(), _report)

        def report(bar):
            # stacklevel=3 skips this helper and the generator frame, so the
            # record points at the code iterating over logged_tqdm.
            log_method(bar, stacklevel=3)

    pbar = tqdm(*a, **kw)
    try:
        for v in pbar:
            report(pbar)
            yield v
    finally:
        # Also close when the consumer stops early (break / generator.close()),
        # not only when the iterable is exhausted.
        pbar.close()
    report(pbar)
314

315

316
def scientific_latex_notation(
    value: ty.Union[str, float, int],
    high: float = 1e3,
    low: float = 1e-3,
    precision: str = '.2e',
):
    r"""Convert a number (or its string representation) to LaTeX scientific notation.

    The value is formatted with ``precision`` and rendered as
    ``$<mantissa>\times 10^{<exponent>}$`` when any of these holds:
    its magnitude is above ``high``, its magnitude is below ``low``, or its
    string form already contains an 'e'. Anything else is returned unchanged.

    Args:
        value: a real number, or a string that may hold one. Strings that
            cannot be parsed as a float are returned unchanged.
        high: magnitudes above this are rendered in scientific notation.
        low: magnitudes below this are rendered in scientific notation.
        precision: format spec for the number; if the formatted result has
            no 'e', the original value is returned.

    Returns:
        the LaTeX string, or ``value`` unchanged.

    Raises:
        TypeError: if ``value`` is neither a string nor a real number.
    """
    import numbers

    if isinstance(value, str):
        try:
            fl = float(value)
        except (TypeError, ValueError):
            return value
    elif isinstance(value, numbers.Real):
        # numbers.Real also admits numpy scalars such as np.float32 / np.int64.
        fl = float(value)
    else:
        raise TypeError(f'misunderstood {value} ({type(value)})')

    if abs(fl) > high or abs(fl) < low or 'e' in str(value):
        fl_s = f'{fl:{precision}}'
        if 'e' not in fl_s:
            return value
        mantissa, _, exponent = fl_s.partition('e')
        # int() parses the signed exponent ('+04' -> 4, '-05' -> -5). The sign
        # of the mantissa must not affect the sign of the exponent.
        return f'${mantissa}\\times 10^{{{int(exponent)}}}$'
    return value
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc