• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / UltraNest / 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC coverage: 74.53% (+0.3%) from 74.242%
9f2dd4f6-0775-47e9-b700-af647027ebfa

push

circleci

web-flow
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler

vectorised slice sampler of fixed batch size

1329 of 2026 branches covered (65.6%)

Branch coverage included in aggregate %.

79 of 80 new or added lines in 1 file covered. (98.75%)

1 existing line in 1 file now uncovered.

4026 of 5159 relevant lines covered (78.04%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.25
/ultranest/viz.py
1
"""
2
Live point visualisations
3
-------------------------
4

5
Gives a live impression of current exploration.
6
This is powerful because the user can abort partial runs if the fit
7
converges to unreasonable values.
8

9
"""
10

11
from __future__ import print_function, division
1✔
12

13
import sys
1✔
14
import shutil
1✔
15
from numpy import log10
1✔
16
import numpy as np
1✔
17
import string
1✔
18
from xml.sax.saxutils import escape as html_escape
1✔
19

20

21
# Single-character labels for cluster ids: the ten digits, then upper-case,
# then lower-case ASCII letters (62 labels in total).
clusteridstrings = [str(digit) for digit in range(10)] + list(string.ascii_uppercase + string.ascii_lowercase)

# scipy is optional: without it, spearman stays None and the correlation
# reports in the loggers below are skipped.
try:
    import scipy.stats
except ImportError:
    spearman = None
else:
    spearman = scipy.stats.spearmanr
29

30

31
def round_parameterlimits(plo, phi, paramlimitguess=None):
    """Guess the current parameter range.

    Parameters
    -----------
    plo: array of floats
        for each parameter, current minimum value
    phi: array of floats
        for each parameter, current maximum value
    paramlimitguess: array of float tuples
        for each parameter, guess of parameter range if available

    Returns
    -------
    plo_rounded: array of floats
        for each parameter, rounded minimum value
    phi_rounded: array of floats
        for each parameter, rounded maximum value.
        Always strictly above plo_rounded unless paramlimitguess narrows it.
    formats: list of str
        for each parameter, %-style format string for representing its values.

    """
    # accept plain sequences as well as numpy arrays (`plo < 0` needs an array)
    plo = np.asarray(plo)
    phi = np.asarray(phi)
    # log10(0) = -inf is expected for parameters touching zero; silence the warning
    with np.errstate(divide='ignore'):
        expos = log10(np.abs([plo, phi]))
    expolo = np.floor(np.min(expos, axis=0))
    expohi = np.ceil(np.max(expos, axis=0))
    is_negative = plo < 0
    # Round outwards to the enclosing power of ten: [-10^k, 10^k] when the
    # parameter reaches negative values, [0, 10^k] otherwise.
    # expohi is -inf only if plo == phi == 0; use 10^0 = 1 there so the range
    # never has zero width (callers divide by phi_rounded - plo_rounded).
    magnitude = np.where(np.isfinite(expohi), 10.0**expohi, 1.0)
    plo_rounded = np.where(is_negative, -magnitude, 0.0)
    phi_rounded = magnitude

    if paramlimitguess is not None:
        for i, (plo_guess, phi_guess) in enumerate(paramlimitguess):
            # if plo_guess is higher than what we thought, we can increase to match
            if plo_guess <= plo[i] and plo_guess >= plo_rounded[i]:
                plo_rounded[i] = plo_guess
            # likewise, a lower phi_guess that still covers the data tightens the top
            if phi_guess >= phi[i] and phi_guess <= phi_rounded[i]:
                phi_rounded[i] = phi_guess

    formats = []
    for lo, hi, elo, ehi in zip(plo, phi, expolo, expohi):
        fmt = '%+.1e'
        if -1 <= elo <= 2 and -1 <= ehi <= 2:
            fmt = '%+.1f'
        if -4 <= elo <= 0 and -4 <= ehi <= 0:
            # enough decimals to show the leading digit of the smaller magnitude
            fmt = '%%+.%df' % max(0, -min(elo, ehi))
        if hi == lo:
            fmt = '%+.1f'
        elif fmt % lo == fmt % hi:
            # bounds print identically: add decimals down to the order of the spread
            fmt = '%%+.%df' % max(0, -int(np.floor(log10(hi - lo))))
        formats.append(fmt)

    return plo_rounded, phi_rounded, formats
83

84

85
def nicelogger(points, info, region, transformLayer, region_fresh=False):
    """Log current live points and integration progress to stdout.

    Parameters
    -----------
    points: dict with keys "u", "p", "logl"
        live points (u: cube coordinates, p: transformed coordinates,
        logl: loglikelihood values)
    info: dict
        integration information. Keys are:

        - paramnames: parameter names
        - paramlims (optional): parameter ranges
        - logvol: expected volume at this iteration
        - order_test_correlation, order_test_direction (optional):
          reported as the "Quality" entry
        - stepsampler_info (optional): step sampler statistics

    region: MLFriends
        Current region.
    transformLayer: ScaleLayer or AffineLayer or MaxPrincipleGapAffineLayer
        Current transformLayer (for clustering information).
    region_fresh: bool
        Whether the region was just updated.

    """
    p = points['p']
    paramnames = info['paramnames']

    plo = p.min(axis=0)
    phi = p.max(axis=0)
    plo_rounded, phi_rounded, paramformats = round_parameterlimits(plo, phi, paramlimitguess=info.get('paramlims'))

    if sys.stderr.isatty() and hasattr(shutil, 'get_terminal_size'):
        columns, _ = shutil.get_terminal_size(fallback=(80, 25))
    else:
        columns, _ = 80, 25

    paramwidth = max(len(pname) for pname in paramnames)
    # leave room for the parameter name and the range labels around the bar
    width = max(columns - 23 - paramwidth, 10)
    # guard against a zero-width range (e.g. every point exactly 0), which
    # would otherwise produce 0/0 = NaN and an undefined integer cast
    span = phi_rounded - plo_rounded
    span = np.where(span > 0, span, 1.0)
    indices = ((p - plo_rounded) * width / span.reshape((1, -1))).astype(int)
    indices[indices >= width] = width - 1
    indices[indices < 0] = 0
    ndim = len(plo)

    print()
    print()
    clusterids = transformLayer.clusterids % len(clusteridstrings)
    nmodes = transformLayer.nclusters
    if 'order_test_correlation' not in info:
        quality = ''
    elif np.isfinite(info['order_test_correlation']):
        quality = "Quality: correlation length: %d (%s)" % (
            info['order_test_correlation'], '+' if info['order_test_direction'] >= 0 else '-')
    else:
        quality = "Quality: ok"
    print(
        "Mono-modal" if nmodes == 1 else "Have %d modes" % nmodes,
        "Volume: ~exp(%.2f)" % region.estimate_volume(), '*' if region_fresh else ' ',
        "Expected Volume: exp(%.2f)" % info['logvol'],
        quality,
    )

    stepsampler_info = info.get('stepsampler_info', {})
    if stepsampler_info.get('num_logs', 0) > 0:
        # Build the message in two steps: a trailing `A + B if cond else ''`
        # blanks the whole line whenever no jump distance was recorded.
        msg = 'Step sampler performance: %(rejection_rate).1f rej/step, %(mean_nsteps)d steps/it' % stepsampler_info
        if 'mean_distance' in stepsampler_info:
            msg += ', rel jump distance: %.2f (should be >1), %.2f%% (should be >50%%)' % (
                stepsampler_info['mean_distance'], 100 * stepsampler_info['frac_far_enough'])
        print(msg)

    print()
    if ndim == 1:
        pass
    elif ndim == 2 and spearman is not None:
        rho, pval = spearman(p)
        if pval < 0.01 and abs(rho) > 0.75:
            print("   %s between %s and %s: rho=%.2f" % (
                'positive degeneracy' if rho > 0 else 'negative degeneracy',
                paramnames[0], paramnames[1], rho))
    elif spearman is not None:
        rho, pval = spearman(p)
        if np.isfinite(pval).all() and pval.ndim == 2:
            for i, param in enumerate(paramnames):
                for j, param2 in enumerate(paramnames[:i]):
                    if pval[i,j] < 0.01 and abs(rho[i,j]) > 0.99:
                        s = 'positive relation' if rho[i,j] > 0 else 'negative relation'
                        print("   perfect %s between %s and %s" % (s, param, param2))
                    elif pval[i,j] < 0.01 and abs(rho[i,j]) > 0.75:
                        s = 'positive degeneracy' if rho[i,j] > 0 else 'negative degeneracy'
                        print("   %s between %s and %s: rho=%.2f" % (s, param, param2, rho[i,j]))

    for i, (param, fmt) in enumerate(zip(paramnames, paramformats)):
        line = [' '] * width
        if nmodes == 1:
            for j in np.unique(indices[:,i]):
                line[j] = '*'
        else:
            for clusterid, j in zip(clusterids, indices[:,i]):
                if clusterid > 0 and line[j] in (' ', '0'):
                    # set it to correct cluster id
                    line[j] = clusteridstrings[clusterid]
                elif clusterid == 0 and line[j] == ' ':
                    # empty, so set it although we don't know the cluster id
                    line[j] = '0'
        line = ''.join(line)

        ilo, ihi = indices[:,i].min(), indices[:,i].max()
        if ilo > 10:
            assert line[:10] == ' ' * 10
            leftstr = fmt % plo[i]
            j = ilo - 2 - len(leftstr)  # left-bound
            if 0 < j < width:
                line = line[:j] + leftstr + line[j + len(leftstr):]
        if ihi < width - 10:
            rightstr = fmt % phi[i]
            j = ihi + 3  # right-bound
            # only place the label if it fits, so every bar keeps exactly `width` characters
            if 0 < j and j + len(rightstr) <= width:
                line = line[:j] + rightstr + line[j + len(rightstr):]

        parampadded = ('%%-%ds' % paramwidth) % param
        print('%s: %09s|%s|%9s' % (parampadded, fmt % plo_rounded[i], line, fmt % phi_rounded[i]))

    print()
208

209

210
def isnotebook():
    """Check if running in a Jupyter notebook.

    Returns
    -------
    bool
        True when the active IPython shell class is ZMQInteractiveShell
        (Jupyter notebook or qtconsole); False for terminal IPython
        (TerminalInteractiveShell), any other shell, or when
        ``get_ipython`` is not defined at all.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is undefined: probably a standard Python interpreter
        return False
    return shell_name == 'ZMQInteractiveShell'
222

223

224
class LivePointsWidget(object):
    """
    Widget for ipython and jupyter notebooks.

    Shows where the live points are currently in parameter space.
    """

    def __init__(self):
        """Initialise. To draw, call .initialize()."""
        # GridspecLayout with one row of HTML cells per parameter (set by initialize)
        self.grid = None
        # HTML widget above the grid carrying the textual summary (set by initialize)
        self.label = None
        # per parameter, the bar string currently displayed; lets updates touch only changed cells
        self.laststatus = None

    def initialize(self, paramnames, width):
        """Set up and display widget.

        Parameters
        ----------
        paramnames: list of str
            Parameter names
        width: int
            number of html table columns.

        """
        from ipywidgets import HTML, VBox, Layout, GridspecLayout
        from IPython.display import display

        # columns: name | lower limit | `width` bar cells | upper limit
        grid = GridspecLayout(len(paramnames), width + 3)
        self.laststatus = []
        for a, paramname in enumerate(paramnames):
            # all cells start coloured; '*' is the status character rendered that way
            self.laststatus.append('*' * width)
            htmlcode = "<div style='background-color:#6E6BF4;'>&nbsp;</div>"
            for b in range(width):
                grid[a, b + 2] = HTML(htmlcode, layout=Layout(margin="0"))
            htmlcode = "<div style='background-color:#FFB858; font-weight:bold; padding-right: 2em;'>%s</div>"
            grid[a, 0] = HTML(htmlcode % html_escape(paramname), layout=Layout(margin="0"))
            grid[a, 1] = HTML("...", layout=Layout(margin="0"))
            grid[a,-1] = HTML("...", layout=Layout(margin="0"))
        self.grid = grid

        self.label = HTML()
        box = VBox(children=[self.label, grid])
        display(box)

    def __call__(self, points, info, region, transformLayer, region_fresh=False):
        """Update widget to show current live points and integration progress.

        The widget is created and displayed on the first call.

        Parameters
        -----------
        points: dict with keys u, p, logl
            live points (u: cube coordinates, p: transformed coordinates,
            logl: loglikelihood values)
        info: dict
            integration information. Keys are:

            - paramnames: parameter names
            - paramlims (optional): parameter ranges
            - logvol: expected volume at this iteration
            - order_test_correlation, order_test_direction (optional):
              shown as the "Quality" entry
            - stepsampler_info (optional): step sampler statistics

        region: MLFriends
            Current region.
        transformLayer: ScaleLayer or AffineLayer or MaxPrincipleGapAffineLayer
            Current transformLayer (for clustering information).
        region_fresh: bool
            Whether the region was just updated.

        """
        p = points['p']
        paramnames = info['paramnames']

        plo = p.min(axis=0)
        phi = p.max(axis=0)
        plo_rounded, phi_rounded, paramformats = round_parameterlimits(plo, phi, paramlimitguess=info.get('paramlims'))

        width = 50

        if self.grid is None:
            self.initialize(paramnames, width)

        # guard against a zero-width range (e.g. every point exactly 0),
        # which would otherwise give 0/0 = NaN indices
        span = phi_rounded - plo_rounded
        span = np.where(span > 0, span, 1.0)
        with np.errstate(invalid="ignore"):
            indices = ((p - plo_rounded) * width / span.reshape((1, -1))).astype(int)
        indices[indices >= width] = width - 1
        indices[indices < 0] = 0
        ndim = len(plo)

        clusterids = transformLayer.clusterids % len(clusteridstrings)
        nmodes = transformLayer.nclusters
        labeltext = ("Mono-modal" if nmodes == 1 else "Have %d modes" % nmodes) + \
            (" | Volume: ~exp(%.2f) " % region.estimate_volume()) + ('*' if region_fresh else ' ') + \
            " | Expected Volume: exp(%.2f)" % info['logvol']
        if 'order_test_correlation' in info:
            if np.isfinite(info['order_test_correlation']):
                labeltext += " | Quality: correlation length: %d (%s)" % (
                    info['order_test_correlation'], '+' if info['order_test_direction'] >= 0 else '-')
            else:
                labeltext += " | Quality: ok"

        stepsampler_info = info.get('stepsampler_info', {})
        if stepsampler_info.get('num_logs', 0) > 0:
            # Build the message in two steps: a trailing `A + B if cond else ''`
            # drops the whole line whenever no jump distance was recorded.
            labeltext += "<br/>" + \
                'Step sampler performance: %(rejection_rate).1f%% rej/step, %(mean_nsteps)d steps/it' % stepsampler_info
            if 'mean_distance' in stepsampler_info:
                labeltext += ', mean rel jump distance: %.2f (should be >1), %.2f%% (should be >50%%)' % (
                    stepsampler_info['mean_distance'], 100 * stepsampler_info['frac_far_enough'])

        # parameter names are escaped because labeltext is rendered as HTML
        if ndim == 1:
            pass
        elif ndim == 2 and spearman is not None:
            rho, pval = spearman(p)
            if pval < 0.01 and abs(rho) > 0.75:
                labeltext += ("<br/>   %s between %s and %s: rho=%.2f" % (
                    'positive degeneracy' if rho > 0 else 'negative degeneracy',
                    html_escape(paramnames[0]), html_escape(paramnames[1]), rho))
        elif spearman is not None:
            rho, pval = spearman(p)
            for i, param in enumerate(paramnames):
                for j, param2 in enumerate(paramnames[:i]):
                    if pval[i,j] < 0.01 and abs(rho[i,j]) > 0.99:
                        labeltext += ("<br/>   perfect %s between %s and %s" % (
                            'positive relation' if rho[i,j] > 0 else 'negative relation',
                            html_escape(param2), html_escape(param)))
                    elif pval[i,j] < 0.01 and abs(rho[i,j]) > 0.75:
                        labeltext += ("<br/>   %s between %s and %s: rho=%.2f" % (
                            'positive degeneracy' if rho[i,j] > 0 else 'negative degeneracy',
                            html_escape(param2), html_escape(param), rho[i,j]))

        for i, (param, fmt) in enumerate(zip(paramnames, paramformats)):
            line = [' '] * width
            if nmodes == 1:
                for j in np.unique(indices[:,i]):
                    line[j] = '*'
            else:
                for clusterid, j in zip(clusterids, indices[:,i]):
                    if clusterid > 0 and line[j] in (' ', '0'):
                        # set it to correct cluster id
                        line[j] = clusteridstrings[clusterid]
                    elif clusterid == 0 and line[j] == ' ':
                        # empty, so set it although we don't know the cluster id
                        line[j] = '0'
            linestr = ''.join(line)

            # only rewrite cells whose character changed since the last update
            oldlinestr = self.laststatus[i]
            for j, (c, d) in enumerate(zip(linestr, oldlinestr)):
                if c != d:
                    if c == ' ':
                        self.grid[i, j + 2].value = "<div style='background-color:white;'>&nbsp;</div>"
                    else:
                        self.grid[i, j + 2].value = "<div style='background-color:#6E6BF4; font-family:monospace'>%s</div>" % c.replace('*', '&nbsp;')

            self.laststatus[i] = linestr
            self.grid[i, 1].value = fmt % plo_rounded[i]
            self.grid[i,-1].value = fmt % phi_rounded[i]

        self.label.value = labeltext

388

389
def get_default_viz_callback():
    """Pick the live-point visualisation callback for the current environment.

    Returns
    -------
    callable
        A new LivePointsWidget instance when running inside a Jupyter
        notebook (as reported by isnotebook), the nicelogger function otherwise.
    """
    return LivePointsWidget() if isnotebook() else nicelogger
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc