• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / UltraNest / 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC coverage: 74.53% (+0.3%) from 74.242%
9f2dd4f6-0775-47e9-b700-af647027ebfa

push

circleci

web-flow
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler

vectorised slice sampler of fixed batch size

1329 of 2026 branches covered (65.6%)

Branch coverage included in aggregate %.

79 of 80 new or added lines in 1 file covered. (98.75%)

1 existing line in 1 file now uncovered.

4026 of 5159 relevant lines covered (78.04%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

65.16
/ultranest/plot.py
1
"""
2
Plotting utilities
3
------------------
4

5
"""
6

7
from __future__ import (print_function, division)
1✔
8
from six.moves import range
1✔
9

10
import logging
1✔
11
import types
1✔
12
import warnings
1✔
13

14
import numpy as np
1✔
15
import matplotlib.pyplot as pl
1✔
16
from matplotlib.ticker import MaxNLocator, NullLocator
1✔
17
# from matplotlib.colors import LinearSegmentedColormap, colorConverter
18
from matplotlib.ticker import ScalarFormatter
1✔
19

20
import scipy.stats
1✔
21
import matplotlib.pyplot as plt
1✔
22
import numpy
1✔
23

24
from .utils import resample_equal
1✔
25
from .utils import quantile as _quantile
1✔
26

27
try:
    # Python 2 exposes these aliases in the types module; on Python 3 they
    # are absent and accessing them raises AttributeError.
    str_type = types.StringTypes
    float_type = types.FloatType
    int_type = types.IntType
except AttributeError:
    str_type = str
    float_type = float
    int_type = int
35

36
import corner
1✔
37

38
__all__ = ["runplot", "cornerplot", "traceplot", "PredictionBand"]
1✔
39

40

41
def cornerplot(
    results, min_weight=1e-4, with_legend=True, logger=None,
    levels=[0.9973, 0.9545, 0.6827, 0.3934],
    plot_datapoints=False, plot_density=False, show_titles=True, quiet=True,
    contour_kwargs=dict(linestyles=['-','-.',':','--'], colors=['navy','navy','navy','purple']),
    color='purple', quantiles=[0.15866, 0.5, 0.8413], **corner_kwargs
):
    """Make a healthy corner plot with corner.

    Essentially does::

        paramnames = results['paramnames']
        data = results['weighted_samples']['points']
        weights = results['weighted_samples']['weights']

        return corner.corner(
            results['weighted_samples']['points'],
            weights=results['weighted_samples']['weights'],
            labels=results['paramnames'])

    Parameters
    ----------
    results: dict
        results dictionary. The keys read are ``paramnames`` and
        ``weighted_samples`` (with sub-keys ``points`` and ``weights``).
    min_weight: float
        cut off low-weight posterior points. Avoids meaningless
        stragglers when plot_datapoints is True. Points are dropped from
        the start of the sample list until their cumulative weight exceeds
        this value.
    with_legend: bool
        whether to add a legend to show meaning of the lines.
    logger: logging.Logger or None
        if given, used to report when the posterior is concentrated in a
        single point.
    levels: list
        credible probability levels of the 2-D contours.
    plot_datapoints : bool
        Draw the individual data points.
    color : str
        ``matplotlib`` style color for all histograms.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours (passed to corner via ``**corner_kwargs``).
    show_titles : bool
        Displays a title above each 1-D histogram showing the 0.5 quantile
        with the upper and lower errors supplied by the quantiles argument.
    quiet : bool
        If true, suppress warnings for small datasets.
    contour_kwargs : dict
        Any additional keyword arguments to pass to the `contour` method.
        The dictionary is copied, not modified.
    quantiles: list
        fractional quantiles to show on the 1-D histograms as vertical dashed lines.
    **corner_kwargs: dict
        Any remaining keyword arguments are sent to :func:`corner.corner`.

    Returns
    -------
    fig : `~matplotlib.figure.Figure` or None
        The ``matplotlib`` figure instance for the corner plot.
        None if only a single posterior point survives the `min_weight` cut.

    """
    paramnames = results['paramnames']
    data = np.array(results['weighted_samples']['points'])
    weights = np.array(results['weighted_samples']['weights'])
    cumsumweights = np.cumsum(weights)

    mask = cumsumweights > min_weight

    if mask.sum() == 1:
        if logger is not None:
            warn = 'Posterior is still concentrated in a single point:'
            for i, p in enumerate(paramnames):
                # index the same weighted points the mask was computed from
                v = data[mask, i]
                warn += "\n" + '    %-20s: %s' % (p, v)

            logger.warning(warn)
            logger.info('Try running longer.')
        return

    # work on a copy so neither the shared default dict nor a
    # caller-provided dict can be modified by corner
    contour_kwargs = dict(contour_kwargs)

    # monkey patch to disable a useless warning; always restored below,
    # even if plotting fails
    oldfunc = logging.warning
    logging.warning = lambda *args, **kwargs: None
    try:
        fig = corner.corner(
            data[mask,:], weights=weights[mask],
            labels=paramnames, show_titles=show_titles, quiet=quiet,
            plot_datapoints=plot_datapoints, plot_density=plot_density,
            levels=levels, quantiles=quantiles,
            contour_kwargs=contour_kwargs, color=color, **corner_kwargs
        )
        # Create legend handles
        if with_legend and data.shape[1] > 1:
            linestyles = contour_kwargs.get('linestyles', [])
            linecolors = contour_kwargs.get('colors', [color] * 100)
            # a single style/color string applies to every contour level;
            # reversing/iterating it would yield single characters
            if isinstance(linestyles, str_type):
                linestyles = [linestyles] * len(levels)
            if isinstance(linecolors, str_type):
                linecolors = [linecolors] * len(levels)
            legend_handles = [
                plt.Line2D(
                    [0], [0], linestyle='--', color=color,
                    label='%.1f%% marginal' % (100 * (quantiles[-1] - quantiles[0]))),
            ] + [plt.Line2D(
                [0], [0], linestyle=ls, color=linecolor,
                label='%.1f%%' % (100 * level))
                for ls, linecolor, level in zip(
                    linestyles[::-1],
                    linecolors[::-1],
                    levels[::-1])
            ]
            # only draw the legend when every level got a handle
            if len(legend_handles) == len(levels) + 1:
                plt.legend(
                    title='credible prob level',
                    handles=legend_handles,
                    loc='lower right', bbox_to_anchor=(1.01,1.2), frameon=False
                )
    finally:
        logging.warning = oldfunc
    return fig
142

143

144
def highest_density_interval_from_samples(xsamples, xlo=None, xhi=None, probability_level=0.68):
    """
    Compute the highest density interval (HDI) from posterior samples.

    Parameters
    ----------
    xsamples : array_like
        The posterior samples from which to compute the HDI.
    xlo : float or None, optional
        Lower boundary limiting the space. Default is None.
    xhi : float or None, optional
        Upper boundary limiting the space. Default is None.
    probability_level : float, optional
        The desired probability level for the HDI, in (0, 1]. Default is 0.68.

    Returns
    -------
    x_MAP: float
        maximum a posteriori (MAP) estimate, the peak of the density estimate.
    xerrlo: float
        lower uncertainty (x_MAP minus lower HDI bound), non-negative.
    xerrhi: float
        upper uncertainty (upper HDI bound minus x_MAP), non-negative.

    Raises
    ------
    ValueError
        if probability_level is not in (0, 1].

    Notes
    -----
    The function starts at the highest density point and accumulates neighboring points
    until the specified probability level is reached. If `xlo` or `xhi` is provided,
    the HDI is constrained within these bounds.

    Requires getdist to be installed for a kernel density estimation.

    For uniform distributions, this function will give unpredictable results for the MAP.

    Examples
    --------
    >>> xsamples = np.random.normal(loc=0, scale=1, size=100000)
    >>> x_MAP, xerrlo, xerrhi = highest_density_interval_from_samples(xsamples)
    >>> print('x = %.1f - %.2f + %.2f' % (x_MAP, xerrlo, xerrhi))
    x = 0.0 - 1.02 + 0.96
    """
    if not 0 < probability_level <= 1:
        raise ValueError("probability_level must be in (0, 1], not %s" % probability_level)

    from getdist.mcsamples import MCSamples
    import getdist.chains
    getdist.chains.print_load_details = False
    samples = MCSamples(
        samples=xsamples, names=['x'], ranges={'x':[xlo,xhi]},
        settings=dict(mult_bias_correction_order=1))
    samples.raise_on_bandwidth_errors = True
    density_bounded = samples.get1DDensityGridData('x')

    x = density_bounded.x
    # normalise the gridded density to probabilities per grid point
    y = density_bounded.P / np.sum(density_bounded.P)

    # Sort the y values in descending order
    sorted_indices = np.argsort(y)[::-1]

    # define MAP as the peak. This works well if the peak is declining to both sides
    MAP = x[sorted_indices[0]]
    i_lo = i_hi = sorted_indices[0]
    total_probability = y[i_lo]
    for i in sorted_indices[1:]:
        # stop as soon as the contiguous interval holds enough probability
        # (this also covers the peak alone already being sufficient)
        if total_probability >= probability_level:
            break
        # widen the contiguous interval to include the next-densest point
        i_lo = min(i_lo, i)
        i_hi = max(i_hi, i)
        total_probability = y[i_lo:i_hi + 1].sum()

    x_lo = x[i_lo]
    x_hi = x[i_hi]
    return MAP, MAP - x_lo, x_hi - MAP
217

218

219
class PredictionBand(object):
    """Plot bands of model predictions as calculated from a chain.

    call add(y) to add predictions from each chain point

    .. testsetup::

        import numpy
        chain = numpy.random.uniform(size=(20, 2))

    .. testcode::

        x = numpy.linspace(0, 1, 100)
        band = PredictionBand(x)
        for c in chain:
            band.add(c[0] * x + c[1])
        # add median line. As an option a matplotlib ax can be given.
        band.line(color='k')
        # add 1 sigma quantile
        band.shade(color='k', alpha=0.3)
        # add wider quantile (0.01-0.99)
        band.shade(q=0.49, color='gray', alpha=0.1)
        plt.show()

    To plot onto a specific axis, use `band.line(..., ax=myaxis)`.

    Parameters
    ----------
    x: array
        The independent variable
    shadeargs: dict
        default matplotlib keyword arguments for :meth:`shade`.
        The dictionary is copied.
    lineargs: dict
        default matplotlib keyword arguments for :meth:`line`.
        The dictionary is copied.

    """

    def __init__(self, x, shadeargs={}, lineargs={}):
        """Initialise with independent variable *x*."""
        self.x = x
        self.ys = []
        # copy, so instances never share (and mutate) the default dicts
        self.shadeargs = dict(shadeargs)
        self.lineargs = dict(lineargs)

    def add(self, y):
        """Add a possible prediction *y*."""
        self.ys.append(y)

    def set_shadeargs(self, **kwargs):
        """Set matplotlib style for shading."""
        self.shadeargs = kwargs

    def set_lineargs(self, **kwargs):
        """Set matplotlib style for line."""
        self.lineargs = kwargs

    def get_line(self, q=0.5):
        """Over prediction space x, get quantile *q*. Default is median."""
        if not 0 <= q <= 1:
            raise ValueError("quantile q must be between 0 and 1, not %s" % q)
        assert len(self.ys) > 0, self.ys
        # quantile across predictions (axis 0), for every x point
        return scipy.stats.mstats.mquantiles(self.ys, q, axis=0)[0]

    def shade(self, q=0.341, ax=None, **kwargs):
        """Plot a shaded region between 0.5-q and 0.5+q, by default 1 sigma."""
        if not 0 <= q <= 0.5:
            # a 99% range spans quantiles 0.005..0.995, i.e. q=0.495
            raise ValueError("quantile distance from the median, q, must be between 0 and 0.5, not %s. For a 99%% quantile range, use q=0.495." % q)
        shadeargs = dict(self.shadeargs)
        shadeargs.update(kwargs)
        lo = self.get_line(0.5 - q)
        hi = self.get_line(0.5 + q)
        if ax is None:
            ax = plt
        return ax.fill_between(self.x, lo, hi, **shadeargs)

    def line(self, ax=None, **kwargs):
        """Plot the median curve."""
        lineargs = dict(self.lineargs)
        lineargs.update(kwargs)
        mid = self.get_line(0.5)
        if ax is None:
            ax = plt
        return ax.plot(self.x, mid, **lineargs)
298

299

300
# the following function is taken from https://github.com/joshspeagle/dynesty/blob/master/dynesty/plotting.py
301
# Copyright (c) 2017 - Present: Josh Speagle and contributors.
302
# Copyright (c) 2014 - 2017: Kyle Barbary and contributors.
303
# https://github.com/joshspeagle/dynesty/blob/master/LICENSE
304
def runplot(results, span=None, logplot=False, kde=True, nkde=1000,
            color='blue', plot_kwargs=None, label_kwargs=None, lnz_error=True,
            lnz_truth=None, truth_color='red', truth_kwargs=None,
            max_x_ticks=8, max_y_ticks=3, use_math_text=True,
            mark_final_live=True, fig=None
            ):
    """Plot live points, ln(likelihood), ln(weight), and ln(evidence) vs. ln(prior volume).

    Parameters
    ----------
    results : dynesty.results.Results instance or dict
        Results from a nested sampling run, in the dynesty layout. The keys
        read are ``niter``, ``logvol``, ``logl``, ``logwt``, ``logz``,
        ``logzerr``, ``weights`` and ``samples_n`` (or, if absent,
        ``nlive``). The results are not modified.
    span : iterable with shape (4,), optional
        A list where each element is either a length-2 tuple containing
        lower and upper bounds *or* a float from `(0., 1.]` giving the
        fraction below the maximum. If a fraction is provided,
        the bounds are chosen to be equal-tailed. An example would be::

            span = [(0., 10.), 0.001, 0.2, (5., 6.)]

        Default is `(0., 1.05 * max(data))` for each element.
    logplot : bool, optional
        Whether to plot the evidence on a log scale. Default is `False`.
    kde : bool, optional
        Whether to use kernel density estimation to estimate and plot
        the PDF of the importance weights as a function of log-volume
        (as opposed to the importance weights themselves). Default is
        `True`.
    nkde : int, optional
        The number of grid points used when plotting the kernel density
        estimate. Default is `1000`.
    color : str or iterable with shape (4,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting the lines in each subplot.
        Default is `'blue'`.
    plot_kwargs : dict, optional
        Extra keyword arguments that will be passed to `plot`.
        The dictionary is copied, not modified.
    label_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_xlabel` and
        `~matplotlib.axes.Axes.set_ylabel` methods.
    lnz_error : bool, optional
        Whether to plot the 1, 2, and 3-sigma approximate error bars
        derived from the ln(evidence) error approximation over the course
        of the run. Default is True.
    lnz_truth : float, optional
        A reference value for the evidence that will be overplotted on the
        evidence subplot if provided.
    truth_color : str, optional
        A `~matplotlib`-style color used when plotting `lnz_truth`.
        Default is `'red'`.
    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting
        `lnz_truth`. The dictionary is copied, not modified.
    max_x_ticks : int, optional
        Maximum number of ticks allowed for the x axis. Default is `8`.
    max_y_ticks : int, optional
        Maximum number of ticks allowed for the y axis. Default is `3`.
    use_math_text : bool, optional
        Whether the axis tick labels for very large/small exponents should be
        displayed as powers of 10 rather than using `e`. Default is `True`.
    mark_final_live : bool, optional
        Whether to indicate the final addition of recycled live points
        (if they were added to the resulting samples) using
        a dashed vertical line. Default is `True`. Ignored when the results
        contain ``samples_n``.
    fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional
        If provided, overplot the run onto the provided figure.
        Otherwise, by default an internal figure is generated.

    Returns
    -------
    runplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`)
        Output summary plot.

    Raises
    ------
    ValueError
        if `span` does not have exactly 4 entries, or if the axes of a
        provided `fig` cannot be reshaped to (4, 1).

    """
    # Initialize values. Copy user dicts, as defaults are filled in below.
    label_kwargs = dict() if label_kwargs is None else dict(label_kwargs)
    plot_kwargs = dict() if plot_kwargs is None else dict(plot_kwargs)
    truth_kwargs = dict() if truth_kwargs is None else dict(truth_kwargs)

    # Set defaults.
    plot_kwargs['linewidth'] = plot_kwargs.get('linewidth', 5)
    plot_kwargs['alpha'] = plot_kwargs.get('alpha', 0.7)
    truth_kwargs['linestyle'] = truth_kwargs.get('linestyle', 'solid')
    truth_kwargs['linewidth'] = truth_kwargs.get('linewidth', 3)

    # Extract results.
    niter = results['niter']  # number of iterations
    logvol = results['logvol']  # ln(prior volume)
    logl = results['logl'] - max(results['logl'])  # ln(normalized likelihood)
    logwt = results['logwt'] - results['logz'][-1]  # ln(importance weight)
    logz = results['logz']  # ln(evidence)
    # error in ln(evidence); copied so that zeroing the non-finite entries
    # does not modify the caller's results
    logzerr = np.array(results['logzerr'], dtype=float)
    weights = results['weights']
    logzerr[~np.isfinite(logzerr)] = 0.
    nsamps = len(logwt)  # number of samples

    # Check whether the run was "static" or "dynamic".
    try:
        nlive = results['samples_n']
        mark_final_live = False
    except Exception:
        nlive = np.ones(niter) * results['nlive']
        if nsamps - niter == results['nlive']:
            nlive_final = np.arange(1, results['nlive'] + 1)[::-1]
            nlive = np.append(nlive, nlive_final)

    # Check if the final set of live points were added to the results.
    if mark_final_live:
        if nsamps - niter == results['nlive']:
            live_idx = niter
        else:
            warnings.warn("The number of iterations and samples differ "
                          "by an amount that isn't the number of final "
                          "live points. `mark_final_live` has been disabled.")
            mark_final_live = False

    # Determine plotting bounds for each subplot.
    data = [nlive, np.exp(logl), weights, logz if logplot else np.exp(logz)]

    # only use a KDE when enough samples carry non-negligible weight
    kde = kde and (weights * len(logvol) > 0.1).sum() > 10
    if kde:
        try:
            from scipy.stats import gaussian_kde
            # Derive kernel density estimate.
            wt_kde = gaussian_kde(resample_equal(-logvol, weights))  # KDE
            logvol_new = np.linspace(logvol[0], logvol[-1], nkde)  # resample
            data[2] = wt_kde.pdf(-logvol_new)  # evaluate KDE PDF
        except ImportError:
            kde = False

    if span is None:
        span = [(0., 1.05 * max(d)) for d in data]
        no_span = True
    else:
        no_span = False
    span = list(span)
    if len(span) != 4:
        raise ValueError("`span` must have exactly one entry for each of the 4 subplots!")
    for i, _ in enumerate(span):
        try:
            ymin, ymax = span[i]
        except Exception:
            # a float entry: fraction of the maximum to use as lower bound
            span[i] = (max(data[i]) * span[i], max(data[i]))
    if lnz_error and no_span:
        if logplot:
            zspan = (logz[-1] - 10.3 * 3. * logzerr[-1],
                     logz[-1] + 1.3 * 3. * logzerr[-1])
        else:
            zspan = (0., 1.05 * np.exp(logz[-1] + 3. * logzerr[-1]))
        span[3] = zspan

    # Setting up default plot layout.
    if fig is None:
        fig, axes = pl.subplots(4, 1, figsize=(16, 16))
        xspan = [(0., -min(logvol)) for _ax in axes]
        yspan = span
    else:
        fig, axes = fig
        try:
            axes.reshape(4, 1)
        except Exception:
            raise ValueError("Provided axes do not match the required shape "
                             "for plotting samples.")
        # If figure is provided, keep previous bounds if they were larger.
        xspan = [ax.get_xlim() for ax in axes]
        yspan = [ax.get_ylim() for ax in axes]
        # One exception: if the bounds are the plotting default `(0., 1.)`,
        # overwrite them.
        xspan = [t if t != (0., 1.) else (None, None) for t in xspan]
        yspan = [t if t != (0., 1.) else (None, None) for t in yspan]

    # Set up bounds for plotting.
    for i in range(4):
        if xspan[i][0] is None:
            xmin = None
        else:
            xmin = min(0., xspan[i][0])
        if xspan[i][1] is None:
            xmax = -min(logvol)
        else:
            xmax = max(-min(logvol), xspan[i][1])
        if yspan[i][0] is None:
            ymin = None
        else:
            ymin = min(span[i][0], yspan[i][0])
        if yspan[i][1] is None:
            ymax = span[i][1]
        else:
            ymax = max(span[i][1], yspan[i][1])
        axes[i].set_xlim([xmin, xmax])
        axes[i].set_ylim([ymin, ymax])

    # Plotting.
    labels = ['Live Points', 'Likelihood\n(normalized)',
              'Importance\nWeight', 'Evidence']
    if kde:
        labels[2] += ' PDF'

    for i, d in enumerate(data):

        # Establish axes.
        ax = axes[i]
        # Set color(s)/colormap(s).
        if isinstance(color, str_type):
            c = color
        else:
            c = color[i]
        # Setup axes.
        if max_x_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_x_ticks))
        if max_y_ticks == 0:
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.yaxis.set_major_locator(MaxNLocator(max_y_ticks))
        # Label axes.
        sf = ScalarFormatter(useMathText=use_math_text)
        ax.yaxis.set_major_formatter(sf)
        ax.set_xlabel(r"$-\ln X$", **label_kwargs)
        ax.set_ylabel(labels[i], **label_kwargs)
        # Plot run.
        if logplot and i == 3:
            ax.plot(-logvol, d, color=c, **plot_kwargs)
        elif kde and i == 2:
            # the KDE PDF is evaluated on its own resampled volume grid
            ax.plot(-logvol_new, d, color=c, **plot_kwargs)
        else:
            ax.plot(-logvol, d, color=c, **plot_kwargs)
        if i == 3 and lnz_error:
            if logplot:
                # skip points far below the visible range
                mask = logz >= ax.get_ylim()[0] - 10
                [ax.fill_between(-logvol[mask], (logz + s * logzerr)[mask],
                                 (logz - s * logzerr)[mask],
                                 color=c, alpha=0.2)
                 for s in range(1, 4)]
            else:
                [ax.fill_between(-logvol, np.exp(logz + s * logzerr),
                                 np.exp(logz - s * logzerr), color=c, alpha=0.2)
                 for s in range(1, 4)]
        # Mark addition of final live points.
        if mark_final_live:
            ax.axvline(-logvol[live_idx], color=c, ls="dashed", lw=2,
                       **plot_kwargs)
            if i == 0:
                ax.axhline(live_idx, color=c, ls="dashed", lw=2,
                           **plot_kwargs)
        # Add truth value(s).
        if i == 3 and lnz_truth is not None:
            if logplot:
                ax.axhline(lnz_truth, color=truth_color, **truth_kwargs)
            else:
                ax.axhline(np.exp(lnz_truth), color=truth_color, **truth_kwargs)

    return fig, axes
565

566

567
def traceplot(results, span=None, quantiles=[0.025, 0.5, 0.975], smooth=0.02,
              post_color='blue', post_kwargs=None, kde=True, nkde=1000,
              trace_cmap='plasma', trace_color=None, trace_kwargs=None,
              connect=False, connect_highlight=10, connect_color='red',
              connect_kwargs=None, max_n_ticks=5, use_math_text=False,
              labels=None, label_kwargs=None,
              show_titles=False, title_fmt=".2f", title_kwargs=None,
              truths=None, truth_color='red', truth_kwargs=None,
              verbose=False, fig=None):
    """Plot traces and marginalized posteriors for each parameter.

    Each parameter gets one row with two panels: the trace (parameter
    value versus -ln(X), coloured by importance weight) and the weighted
    1-D marginal posterior.

    Parameters
    ----------
    results : dict-like
        Results of a nested sampling run. The keys ``'samples'`` (shape
        ``(nsamps, ndim)`` or ``(nsamps,)``), ``'logvol'`` (shape
        ``(nsamps,)``) and ``'weights'`` (shape ``(nsamps,)``) are read;
        ``'samples_id'`` is additionally required when `connect=True`.
        A `~dynesty.results.Results` instance provides these keys.
    span : iterable with shape (ndim,), optional
        A list where each element is either a length-2 tuple containing
        lower and upper bounds or a float from `(0., 1.]` giving the
        fraction of (weighted) samples to include. If a fraction is provided,
        the bounds are chosen to be equal-tailed. An example would be::

            span = [(0., 10.), 0.95, (5., 6.)]

        Default is `0.999999426697` (5-sigma credible interval) for each
        parameter.
    quantiles : iterable, optional
        A list of fractional quantiles to overplot on the 1-D marginalized
        posteriors as vertical dashed lines. Default is `[0.025, 0.5, 0.975]`
        (the 95%/2-sigma credible interval). The list is not modified.
    smooth : float or iterable with shape (ndim,), optional
        The standard deviation (either a single value or a different value for
        each subplot) for the Gaussian kernel used to smooth the 1-D
        marginalized posteriors, expressed as a fraction of the span.
        Default is `0.02` (2% smoothing). If an integer is provided instead,
        this will instead default to a simple (weighted) histogram with
        `bins=smooth`. When KDE is disabled (see `kde`), a float value
        yields an unsmoothed 40-bin histogram instead.
    post_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting the histograms.
        Default is `'blue'`.
    post_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the
        marginalized 1-D posteriors. The caller's dict is not modified.
    kde : bool, optional
        Whether to use kernel density estimation to estimate and plot
        the PDF of the importance weights as a function of log-volume
        (as opposed to the importance weights themselves). It is switched
        off automatically when no more than 10 samples have
        ``weight * nsamps > 0.1``. Default is `True`.
    nkde : int, optional
        The number of grid points used when plotting the kernel density
        estimate. Default is `1000`.
    trace_cmap : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style colormap (either a single colormap or a
        different colormap for each subplot) used when plotting the traces,
        where each point is colored according to its weight. Default is
        `'plasma'`.
    trace_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a
        different color for each subplot) used when plotting the traces.
        This overrides the `trace_cmap` option by giving all points
        the same color. Default is `None` (not used).
    trace_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the traces.
        The caller's dict is not modified.
    connect : bool, optional
        Whether to draw lines connecting the paths of unique particles.
        Default is `False`.
    connect_highlight : int or iterable, optional
        If `connect=True`, highlights the paths of a specific set of
        particles. If an integer is passed, :data:`connect_highlight`
        random particle paths will be highlighted. If an iterable is passed,
        then the particle paths corresponding to the provided indices
        will be highlighted.
    connect_color : str, optional
        The color of the highlighted particle paths. Default is `'red'`.
    connect_kwargs : dict, optional
        Extra keyword arguments used for plotting particle paths.
        The caller's dict is not modified.
    max_n_ticks : int, optional
        Maximum number of ticks allowed. Default is `5`.
    use_math_text : bool, optional
        Whether the axis tick labels for very large/small exponents should be
        displayed as powers of 10 rather than using `e`. Default is `False`.
    labels : iterable with shape (ndim,), optional
        A list of names for each parameter. If not provided, the default name
        used when plotting will follow :math:`x_i` style.
    label_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_xlabel` and
        `~matplotlib.axes.Axes.set_ylabel` methods.
    show_titles : bool, optional
        Whether to display a title above each 1-D marginalized posterior
        showing the 0.5 quantile along with the upper/lower bounds associated
        with the 0.025 and 0.975 (95%/2-sigma credible interval) quantiles.
        Default is `False`.
    title_fmt : str, optional
        The format string for the quantiles provided in the title. Default is
        `'.2f'`. If `None`, no title is drawn.
    title_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_title` command.
    truths : iterable with shape (ndim,), optional
        A list of reference values that will be overplotted on the traces and
        marginalized 1-D posteriors as solid horizontal/vertical lines.
        Individual values can be exempt using `None`. Default is `None`.
    truth_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting `truths`.
        Default is `'red'`.
    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the vertical
        and horizontal lines with `truths`. The caller's dict is not modified.
    verbose : bool, optional
        Whether to print the values of the computed quantiles associated with
        each parameter. Default is `False`.
    fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional
        If provided, overplot the traces and marginalized 1-D posteriors
        onto the provided figure. The axes may be any array-like holding
        ``2 * ndim`` axes (it is viewed as shape ``(ndim, 2)``). Otherwise,
        by default an internal figure is generated.

    Returns
    -------
    traceplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`)
        Output trace plot. The axes object is returned exactly as created
        or as passed in via `fig`.

    Raises
    ------
    ValueError
        If samples, weights, ln(volume)s, span, sample IDs or provided axes
        have inconsistent or invalid shapes.
    """
    from matplotlib.colors import is_color_like

    # Work on copies so that filling in defaults never mutates dicts that
    # belong to the caller (which could leak settings into later calls).
    title_kwargs = dict(title_kwargs or {})
    label_kwargs = dict(label_kwargs or {})
    trace_kwargs = dict(trace_kwargs or {})
    connect_kwargs = dict(connect_kwargs or {})
    post_kwargs = dict(post_kwargs or {})
    truth_kwargs = dict(truth_kwargs or {})

    # Set defaults; user-supplied values take precedence.
    connect_kwargs.setdefault('alpha', 0.7)
    post_kwargs.setdefault('alpha', 0.6)
    trace_kwargs.setdefault('s', 3)
    truth_kwargs.setdefault('linestyle', 'solid')
    truth_kwargs.setdefault('linewidth', 2)

    # Extract weighted samples.
    samples = np.atleast_1d(results['samples'])
    logvol = np.asarray(results['logvol'])
    weights = np.asarray(results['weights'])

    # Bring samples into shape (ndim, nsamps); 1-D input is one parameter.
    if samples.ndim == 1:
        samples = np.atleast_2d(samples)
    elif samples.ndim == 2:
        samples = samples.T
    else:
        raise ValueError("Samples must be 1- or 2-D.")
    if samples.shape[0] > samples.shape[1]:
        raise ValueError("There are more dimensions than samples!")
    ndim, nsamps = samples.shape

    # Validate weights and ln(volume) before they are used (KDE below).
    if weights.ndim != 1:
        raise ValueError("Weights must be 1-D.")
    if nsamps != weights.shape[0]:
        raise ValueError("The number of weights and samples disagree!")
    if logvol.ndim != 1:
        raise ValueError("Ln(volume)'s must be 1-D.")
    if nsamps != logvol.shape[0]:
        raise ValueError("The number of ln(volume)'s and samples disagree!")

    # Trace colours: raw importance weights, or their KDE-smoothed density
    # in -ln(X) when enough samples carry non-negligible weight.
    wts = weights
    kde = kde and (weights * len(logvol) > 0.1).sum() > 10
    if kde:
        try:
            from scipy.ndimage import gaussian_filter as norm_kde
            from scipy.stats import gaussian_kde
        except ImportError:
            kde = False
        else:
            wt_kde = gaussian_kde(resample_equal(-logvol, weights))
            logvol_grid = np.linspace(logvol[0], logvol[-1], nkde)
            wt_grid = wt_kde.pdf(-logvol_grid)
            wts = np.interp(-logvol, -logvol_grid, wt_grid)

    # Check sample IDs and pick the particle paths to highlight.
    if connect:
        try:
            samples_id = np.asarray(results['samples_id'])
            uid = np.unique(samples_id)
        except Exception:
            raise ValueError("Sample IDs are not defined!")
        try:
            connect_highlight[0]  # probe: indexable means explicit IDs
            ids = connect_highlight
        except (TypeError, IndexError, KeyError):
            ids = np.random.choice(uid, size=connect_highlight, replace=False)

    # Determine plotting bounds for marginalized 1-D posteriors.
    if span is None:
        span = [0.999999426697] * ndim
    span = list(span)
    if len(span) != ndim:
        raise ValueError("Dimension mismatch between samples and span.")
    for i, bound in enumerate(span):
        try:
            _lo, _hi = bound  # explicit (lower, upper) bounds: keep as-is
        except (TypeError, ValueError):
            # A credible fraction: convert to equal-tailed quantile bounds.
            q = [0.5 - 0.5 * bound, 0.5 + 0.5 * bound]
            span[i] = _quantile(samples[i], q, weights=weights)

    # Setting up labels.
    if labels is None:
        labels = [r"$x_{%d}$" % (i + 1) for i in range(ndim)]

    # Setting up smoothing.
    if isinstance(smooth, (int_type, float_type)):
        smooth = [smooth] * ndim

    # Setting up plot layout. `axes` is returned untouched; `axes_grid` is
    # an (ndim, 2) view used for indexing, so flat axes arrays/lists work.
    if fig is None:
        fig, axes = pl.subplots(ndim, 2, figsize=(12, 3 * ndim))
    else:
        fig, axes = fig
    try:
        axes_grid = np.asarray(axes).reshape(ndim, 2)
    except (TypeError, ValueError):
        raise ValueError("Provided axes do not match the required shape "
                         "for plotting samples.")

    for i, x in enumerate(samples):
        if ndim == 1:
            # Historical 1-D layout: posterior on the left, trace on the right.
            trace_ax, post_ax = axes_grid[0, 1], axes_grid[0, 0]
        else:
            trace_ax, post_ax = axes_grid[i, 0], axes_grid[i, 1]

        # A single colour (string or RGB(A) tuple) is shared by all rows;
        # otherwise `truth_color` is a per-parameter sequence.
        if is_color_like(truth_color):
            tcolor = truth_color
        else:
            tcolor = truth_color[i]

        # ---- Trace panel -------------------------------------------------
        ax = trace_ax
        if trace_color is None:
            color = wts
        elif isinstance(trace_color, str_type):
            color = trace_color
        else:
            color = trace_color[i]
        cmap = trace_cmap if isinstance(trace_cmap, str_type) else trace_cmap[i]

        ax.set_xlim([0., -np.min(logvol)])
        ax.set_ylim([np.min(x), np.max(x)])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks))
            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks))
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))
        ax.set_xlabel(r"$-\ln X$", **label_kwargs)
        ax.set_ylabel(labels[i], **label_kwargs)

        ax.scatter(-logvol, x, c=color, cmap=cmap, **trace_kwargs)
        if connect:
            # Add lines highlighting specific particle paths.
            for j in ids:
                sel = samples_id == j
                ax.plot(-logvol[sel], x[sel], color=connect_color,
                        **connect_kwargs)
        if truths is not None and truths[i] is not None:
            for t in np.atleast_1d(truths[i]):
                ax.axhline(t, color=tcolor, **truth_kwargs)

        # ---- Marginalized 1-D posterior panel ----------------------------
        ax = post_ax
        color = post_color if isinstance(post_color, str_type) else post_color[i]

        ax.set_xlim(span[i])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks))
        ax.yaxis.set_major_locator(NullLocator())
        ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))
        ax.set_xlabel(labels[i], **label_kwargs)

        s = smooth[i]
        if isinstance(s, int_type):
            # Integer `s`: weighted histogram with `s` bins inside the span.
            y0, _, _ = ax.hist(x, bins=s, weights=weights, color=color,
                               range=np.sort(span[i]), **post_kwargs)
        else:
            # Float `s`: with KDE, oversample by 10x relative to the kernel
            # width and smooth with a Gaussian filter; otherwise 40 bins.
            bins = int(round(10. / s)) if kde else 40
            n, b = np.histogram(x, bins=bins, weights=weights,
                                range=np.sort(span[i]))
            x0 = 0.5 * (b[1:] + b[:-1])
            y0 = norm_kde(n, 10.) if kde else n
            ax.fill_between(x0, y0, color=color, **post_kwargs)
        ymax = np.max(y0) * 1.05 if len(y0) > 0 else 0.
        if ymax > 0:
            # Skip for an empty histogram: a zero-height range is singular.
            ax.set_ylim([0., ymax])

        # Plot quantiles.
        if quantiles is not None and len(quantiles) > 0:
            qs = _quantile(x, quantiles, weights=weights)
            for q in qs:
                ax.axvline(q, lw=2, ls="dashed", color=color)
            if verbose:
                print("Quantiles:")
                print(labels[i], list(zip(quantiles, qs)))

        if truths is not None and truths[i] is not None:
            for t in np.atleast_1d(truths[i]):
                ax.axvline(t, color=tcolor, **truth_kwargs)

        # Title: median with 95% credible-interval offsets.
        if show_titles and title_fmt is not None:
            ql, qm, qh = _quantile(x, [0.025, 0.5, 0.975], weights=weights)
            q_minus, q_plus = qm - ql, qh - qm
            fmt = "{{0:{0}}}".format(title_fmt).format
            title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
            title = title.format(fmt(qm), fmt(q_minus), fmt(q_plus))
            title = "{0} = {1}".format(labels[i], title)
            ax.set_title(title, **title_kwargs)

    return fig, axes
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc