• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / UltraNest / efd8fd56-40f1-4ae5-8e75-d7c23dd7b563

26 Oct 2024 03:12PM UTC coverage: 74.189% (-0.02%) from 74.207%
efd8fd56-40f1-4ae5-8e75-d7c23dd7b563

push

circleci

JohannesBuchner
doc: make pydocstyle happy

1235 of 1930 branches covered (63.99%)

Branch coverage included in aggregate %.

4028 of 5164 relevant lines covered (78.0%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

65.1
/ultranest/plot.py
1
# noqa: D400 D205
2
"""
3
Plotting utilities
4
------------------
5

6
"""
7

8
from __future__ import division, print_function
1✔
9

10
import logging
1✔
11
import types
1✔
12
import warnings
1✔
13

14
import matplotlib.pyplot as pl
1✔
15
import matplotlib.pyplot as plt
1✔
16
import numpy
1✔
17
import numpy as np
1✔
18
import scipy.stats
1✔
19
# from matplotlib.colors import LinearSegmentedColormap, colorConverter
20
from matplotlib.ticker import MaxNLocator, NullLocator, ScalarFormatter
1✔
21
from six.moves import range
1✔
22

23
from .utils import quantile as _quantile
1✔
24
from .utils import resample_equal
1✔
25

26
# Python 2/3 compatible aliases for the basic scalar types.  On Python 3 the
# ``types`` module has no StringTypes/FloatType/IntType, so the builtins are used.
try:
    str_type, float_type, int_type = types.StringTypes, types.FloatType, types.IntType
except Exception:
    str_type, float_type, int_type = str, float, int
34

35
import corner
1✔
36

37
__all__ = ["runplot", "cornerplot", "traceplot", "PredictionBand"]
1✔
38

39

40
def cornerplot(
    results, min_weight=1e-4, with_legend=True, logger=None,
    levels=[0.9973, 0.9545, 0.6827, 0.3934],
    plot_datapoints=False, plot_density=False, show_titles=True, quiet=True,
    contour_kwargs=dict(linestyles=['-', '-.', ':', '--'], colors=['navy', 'navy', 'navy', 'purple']),
    color='purple', quantiles=[0.15866, 0.5, 0.8413], **corner_kwargs
):
    """Make a healthy corner plot with corner.

    Essentially does::

        paramnames = results['paramnames']
        data = results['weighted_samples']['points']
        weights = results['weighted_samples']['weights']

        return corner.corner(
            results['weighted_samples']['points'],
            weights=results['weighted_samples']['weights'],
            labels=results['paramnames'])

    Parameters
    ----------
    results: dict
        data dictionary; ``paramnames`` and the ``points`` and ``weights``
        entries of ``weighted_samples`` are read.
    min_weight: float
        cut off low-weight posterior points: the leading points whose
        cumulative weight does not exceed this value are dropped. Avoids
        meaningless stragglers when plot_datapoints is True.
    with_legend: bool
        whether to add a legend to show meaning of the lines.
    logger: None | object
        where to log
    levels: list
        list of credible interval levels
    plot_datapoints : bool
        Draw individual data points.
    plot_density : bool
        Draw the density colormap.
    show_titles : bool
        Displays a title above each 1-D histogram showing the 0.5 quantile
        with the upper and lower errors supplied by the quantiles argument.
    quiet : bool
        If true, suppress warnings for small datasets.
    contour_kwargs : dict
        Any additional keyword arguments to pass to the `contour` method.
        A copy is passed on, so this dictionary is never modified.
    color : str
        ``matplotlib`` style color for all histograms.
    quantiles: list
        fractional quantiles to show on the 1-D histograms as vertical dashed lines.
    **corner_kwargs: dict
        Any remaining keyword arguments are sent to :func:`corner.corner`.

    Returns
    -------
    fig : `~matplotlib.figure.Figure` or None
        The ``matplotlib`` figure instance for the corner plot, or None if
        only a single point survives the ``min_weight`` cut (a warning is
        then logged to ``logger``, if one is given).

    """
    paramnames = results['paramnames']
    data = np.array(results['weighted_samples']['points'])
    weights = np.array(results['weighted_samples']['weights'])
    cumsumweights = np.cumsum(weights)

    # keep only points after the cumulative weight has passed min_weight
    mask = cumsumweights > min_weight

    if mask.sum() == 1:
        if logger is not None:
            warn = 'Posterior is still concentrated in a single point:'
            for i, p in enumerate(paramnames):
                # index the weighted points: ``mask`` was built from them
                v = data[mask, i]
                warn += "\n" + '    %-20s: %s' % (p, v)

            logger.warning(warn)
            logger.info('Try running longer.')
        return

    # Work on a copy so neither the shared default dict nor the caller's
    # dict can be modified further down the call chain.
    contour_kwargs = dict(contour_kwargs)

    # monkey patch to disable a useless warning emitted via logging.warning
    # while corner runs; always restored, even if plotting fails.
    oldfunc = logging.warning
    logging.warning = lambda *args, **kwargs: None
    try:
        fig = corner.corner(
            data[mask, :], weights=weights[mask],
            labels=paramnames, show_titles=show_titles, quiet=quiet,
            plot_datapoints=plot_datapoints, plot_density=plot_density,
            levels=levels, quantiles=quantiles,
            contour_kwargs=contour_kwargs, color=color, **corner_kwargs
        )
        # Create legend handles: one for the marginal quantile range, then
        # one per contour level (listed from innermost to outermost).
        if with_legend and data.shape[1] > 1:
            legend_handles = [
                plt.Line2D(
                    [0], [0], linestyle='--', color=color,
                    label='%.1f%% marginal' % (100 * (quantiles[-1] - quantiles[0]))),
            ] + [plt.Line2D(
                [0], [0], linestyle=ls, color=linecolor,
                label='%.1f%%' % (100 * level))
                for ls, linecolor, level in zip(
                    contour_kwargs.get('linestyles', [])[::-1],
                    contour_kwargs.get('colors', [color] * 100)[::-1],
                    levels[::-1])
            ]
            # only draw the legend if every level got a matching line style
            if len(legend_handles) == len(levels) + 1 and len(legend_handles) > 0:
                plt.legend(
                    title='credible prob level',
                    handles=legend_handles,
                    loc='lower right', bbox_to_anchor=(1.01, 1.2), frameon=False
                )
    finally:
        logging.warning = oldfunc
    return fig
147

148

149
def highest_density_interval_from_samples(xsamples, xlo=None, xhi=None, probability_level=0.68):
    """
    Compute the highest density interval (HDI) from posterior samples.

    Parameters
    ----------
    xsamples : array_like
        The posterior samples from which to compute the HDI.
    xlo : float or None, optional
        Lower boundary limiting the space. Default is None.
    xhi : float or None, optional
        Upper boundary limiting the space. Default is None.
    probability_level : float, optional
        The desired probability level for the HDI. Default is 0.68.

    Returns
    -------
    x_MAP: float
        maximum a posteriori (MAP) estimate, i.e. the grid point with the
        highest estimated density.
    xerrlo: float
        lower uncertainty (x_MAP minus lower HDI bound).
    xerrhi: float
        upper uncertainty (upper HDI bound minus x_MAP).

    Notes
    -----
    The density is estimated on a grid. Grid points are visited in order of
    decreasing density, and the contiguous interval is grown to include each
    visited point until the probability inside it reaches the specified
    level. The result is always a single contiguous interval, so for
    multi-modal distributions it also spans the gaps between modes.
    If `xlo` or `xhi` is provided, the HDI is constrained within these bounds.

    Requires getdist to be installed for a kernel density estimation.

    For uniform distributions, this function will give unpredictable results for the MAP.

    Examples
    --------
    >>> xsamples = np.random.normal(loc=0, scale=1, size=100000)
    >>> hdi = highest_density_interval_from_samples(xsamples)
    >>> print('x = %.1f + %.2f - %.2f' % hdi)
    x = 0.0 + 1.02 - 0.96
    """
    import getdist.chains
    from getdist.mcsamples import MCSamples
    # presumably silences getdist's chain-loading printout
    getdist.chains.print_load_details = False
    samples = MCSamples(
        samples=xsamples, names=['x'], ranges={'x': [xlo, xhi]},
        settings=dict(mult_bias_correction_order=1))
    samples.raise_on_bandwidth_errors = True
    density_bounded = samples.get1DDensityGridData('x')

    x = density_bounded.x
    # normalise the gridded density so that it sums to one over the grid
    y = density_bounded.P / np.sum(density_bounded.P)

    # Sort the y values in descending order
    sorted_indices = np.argsort(y)[::-1]

    # define MAP as the peak. This works well if the peak is declining to both sides
    MAP = x[sorted_indices[0]]
    total_probability = y[sorted_indices[0]]
    i_lo = sorted_indices[0]
    i_hi = sorted_indices[0]
    for i in sorted_indices[1:]:
        # Stop as soon as the interval holds enough probability. This is
        # checked before growing, so a peak bin that alone reaches the
        # level is not widened needlessly.
        if total_probability >= probability_level:
            break
        # Grow the contiguous interval to include the next-densest point
        i_lo = min(i_lo, i)
        i_hi = max(i_hi, i)
        total_probability = y[i_lo:i_hi + 1].sum()

    x_lo = x[i_lo]
    x_hi = x[i_hi]
    return MAP, MAP - x_lo, x_hi - MAP
222

223

224
class PredictionBand:
    """Plot bands of model predictions as calculated from a chain.

    call add(y) to add predictions from each chain point

    .. testsetup::

        import numpy
        chain = numpy.random.uniform(size=(20, 2))

    .. testcode::

        x = numpy.linspace(0, 1, 100)
        band = PredictionBand(x)
        for c in chain:
            band.add(c[0] * x + c[1])
        # add median line. As an option a matplotlib ax can be given.
        band.line(color='k')
        # add 1 sigma quantile
        band.shade(color='k', alpha=0.3)
        # add wider quantile (0.01 to 0.99 quantile range)
        band.shade(q=0.49, color='gray', alpha=0.1)
        plt.show()

    To plot onto a specific axis, use `band.line(..., ax=myaxis)`.
    """

    def __init__(self, x, shadeargs={}, lineargs={}):
        """Initialise.

        Parameters
        ----------
        x: array
            Independent variable.
        shadeargs: dict
            default arguments for shade function. A copy is stored.
        lineargs: dict
            default arguments for line function. A copy is stored.
        """
        self.x = x
        self.ys = []
        # copy, so instances never share (and mutate) the default dicts
        self.shadeargs = dict(shadeargs)
        self.lineargs = dict(lineargs)

    def add(self, y):
        """Add a possible prediction *y*."""
        self.ys.append(y)

    def set_shadeargs(self, **kwargs):
        """Set matplotlib style for shading."""
        self.shadeargs = kwargs

    def set_lineargs(self, **kwargs):
        """Set matplotlib style for line."""
        self.lineargs = kwargs

    def get_line(self, q=0.5):
        """Over prediction space x, get quantile *q*. Default is median."""
        if not 0 <= q <= 1:
            raise ValueError("quantile q must be between 0 and 1, not %s" % q)
        assert len(self.ys) > 0, self.ys
        # quantile across the collected predictions, evaluated per x point
        return scipy.stats.mstats.mquantiles(self.ys, q, axis=0)[0]

    def shade(self, q=0.341, ax=None, **kwargs):
        """Plot a shaded region between 0.5-q and 0.5+q, by default 1 sigma."""
        if not 0 <= q <= 0.5:
            raise ValueError("quantile distance from the median, q, must be between 0 and 0.5, not %s. For a 99%% quantile range, use q=0.495." % q)
        shadeargs = dict(self.shadeargs)
        shadeargs.update(kwargs)
        lo = self.get_line(0.5 - q)
        hi = self.get_line(0.5 + q)
        if ax is None:
            ax = plt
        return ax.fill_between(self.x, lo, hi, **shadeargs)

    def line(self, ax=None, **kwargs):
        """Plot the median curve."""
        lineargs = dict(self.lineargs)
        lineargs.update(kwargs)
        mid = self.get_line(0.5)
        if ax is None:
            ax = plt
        return ax.plot(self.x, mid, **lineargs)
307

308

309
# the following function is taken from https://github.com/joshspeagle/dynesty/blob/master/dynesty/plotting.py
310
# Copyright (c) 2017 - Present: Josh Speagle and contributors.
311
# Copyright (c) 2014 - 2017: Kyle Barbary and contributors.
312
# https://github.com/joshspeagle/dynesty/blob/master/LICENSE
313
def runplot(results, span=None, logplot=False, kde=True, nkde=1000,
            color='blue', plot_kwargs=None, label_kwargs=None, lnz_error=True,
            lnz_truth=None, truth_color='red', truth_kwargs=None,
            max_x_ticks=8, max_y_ticks=3, use_math_text=True,
            mark_final_live=True, fig=None
            ):
    """Plot live points, ln(likelihood), ln(weight), and ln(evidence) vs. ln(prior volume).

    Parameters
    ----------
    results : dict-like
        Results from a nested sampling run (dynesty-style layout). The keys
        read are ``niter``, ``logvol``, ``logl``, ``logwt``, ``logz``,
        ``logzerr``, ``weights`` and either ``samples_n`` or ``nlive``.
        The results are not modified.
    span : iterable with shape (4,), optional
        A list where each element is either a length-2 tuple containing
        lower and upper bounds *or* a float from `(0., 1.]` giving the
        fraction below the maximum. If a fraction is provided,
        the bounds are chosen to be equal-tailed. An example would be::

            span = [(0., 10.), 0.001, 0.2, (5., 6.)]

        Default is `(0., 1.05 * max(data))` for each element.
    logplot : bool, optional
        Whether to plot the evidence on a log scale. Default is `False`.
    kde : bool, optional
        Whether to use kernel density estimation to estimate and plot
        the PDF of the importance weights as a function of log-volume
        (as opposed to the importance weights themselves). Default is
        `True`.
    nkde : int, optional
        The number of grid points used when plotting the kernel density
        estimate. Default is `1000`.
    color : str or iterable with shape (4,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting the lines in each subplot.
        Default is `'blue'`.
    plot_kwargs : dict, optional
        Extra keyword arguments that will be passed to `plot`.
        The dictionary itself is not modified.
    label_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_xlabel` and
        `~matplotlib.axes.Axes.set_ylabel` methods.
    lnz_error : bool, optional
        Whether to plot the 1, 2, and 3-sigma approximate error bars
        derived from the ln(evidence) error approximation over the course
        of the run. Default is True.
    lnz_truth : float, optional
        A reference value for the evidence that will be overplotted on the
        evidence subplot if provided.
    truth_color : str, optional
        A `~matplotlib`-style color used when plotting `lnz_truth`.
        Default is `'red'`.
    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting
        `lnz_truth`. The dictionary itself is not modified.
    max_x_ticks : int, optional
        Maximum number of ticks allowed for the x axis. Default is `8`.
    max_y_ticks : int, optional
        Maximum number of ticks allowed for the y axis. Default is `3`.
    use_math_text : bool, optional
        Whether the axis tick labels for very large/small exponents should be
        displayed as powers of 10 rather than using `e`. Default is `True`.
    mark_final_live : bool, optional
        Whether to indicate the final addition of recycled live points
        (if they were added to the resulting samples) using
        a dashed vertical line. Default is `True`. Ignored (disabled) when
        ``results`` contains ``samples_n``.
    fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional
        If provided, overplot the run onto the provided figure.
        Otherwise, by default an internal figure is generated.

    Returns
    -------
    runplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`)
        Output summary plot.

    Raises
    ------
    ValueError
        If `span` does not have exactly 4 entries, or the provided axes
        cannot be reshaped to (4, 1).

    """
    # Initialize values. Copies are taken so the defaults filled in below
    # never leak into the caller's dictionaries.
    label_kwargs = dict() if label_kwargs is None else dict(label_kwargs)
    plot_kwargs = dict() if plot_kwargs is None else dict(plot_kwargs)
    truth_kwargs = dict() if truth_kwargs is None else dict(truth_kwargs)

    # Set defaults.
    plot_kwargs['linewidth'] = plot_kwargs.get('linewidth', 5)
    plot_kwargs['alpha'] = plot_kwargs.get('alpha', 0.7)
    truth_kwargs['linestyle'] = truth_kwargs.get('linestyle', 'solid')
    truth_kwargs['linewidth'] = truth_kwargs.get('linewidth', 3)

    # Extract results.
    niter = results['niter']  # number of iterations
    logvol = results['logvol']  # ln(prior volume)
    logl = results['logl'] - max(results['logl'])  # ln(normalized likelihood)
    logwt = results['logwt'] - results['logz'][-1]  # ln(importance weight)
    logz = results['logz']  # ln(evidence)
    # error in ln(evidence); copied so that zeroing non-finite entries does
    # not modify the caller's results
    logzerr = np.array(results['logzerr'], dtype=float)
    weights = results['weights']
    logzerr[~np.isfinite(logzerr)] = 0.
    nsamps = len(logwt)  # number of samples

    # Check whether the run was "static" or "dynamic".
    try:
        nlive = results['samples_n']
        mark_final_live = False
    except Exception:
        nlive = np.ones(niter) * results['nlive']
        if nsamps - niter == results['nlive']:
            nlive_final = np.arange(1, results['nlive'] + 1)[::-1]
            nlive = np.append(nlive, nlive_final)

    # Check if the final set of live points were added to the results.
    if mark_final_live:
        if nsamps - niter == results['nlive']:
            live_idx = niter
        else:
            warnings.warn("The number of iterations and samples differ "
                          "by an amount that isn't the number of final "
                          "live points. `mark_final_live` has been disabled.",
                          stacklevel=3)
            mark_final_live = False

    # Determine plotting bounds for each subplot.
    data = [nlive, np.exp(logl), weights, logz if logplot else np.exp(logz)]

    # only use a KDE if enough points carry non-negligible weight
    kde = kde and (weights * len(logvol) > 0.1).sum() > 10
    if kde:
        try:
            from scipy.stats import gaussian_kde

            # Derive kernel density estimate.
            wt_kde = gaussian_kde(resample_equal(-logvol, weights))  # KDE
            logvol_new = np.linspace(logvol[0], logvol[-1], nkde)  # resample
            data[2] = wt_kde.pdf(-logvol_new)  # evaluate KDE PDF
        except ImportError:
            kde = False

    if span is None:
        span = [(0., 1.05 * max(d)) for d in data]
        no_span = True
    else:
        no_span = False
    span = list(span)
    if len(span) != 4:
        raise ValueError("`span` must provide exactly 4 bounds, one per subplot!")
    for i, _ in enumerate(span):
        try:
            ymin, ymax = span[i]
        except Exception:
            # a scalar fraction: bounds from that fraction of the maximum up to the maximum
            span[i] = (max(data[i]) * span[i], max(data[i]))
    if lnz_error and no_span:
        if logplot:
            zspan = (logz[-1] - 10.3 * 3. * logzerr[-1],
                     logz[-1] + 1.3 * 3. * logzerr[-1])
        else:
            zspan = (0., 1.05 * np.exp(logz[-1] + 3. * logzerr[-1]))
        span[3] = zspan

    # Setting up default plot layout.
    if fig is None:
        fig, axes = pl.subplots(4, 1, figsize=(16, 16))
        xspan = [(0., -min(logvol)) for _ax in axes]
        yspan = span
    else:
        fig, axes = fig
        try:
            axes.reshape(4, 1)
        except Exception:
            raise ValueError("Provided axes do not match the required shape "
                             "for plotting samples.")
        # If figure is provided, keep previous bounds if they were larger.
        xspan = [ax.get_xlim() for ax in axes]
        yspan = [ax.get_ylim() for ax in axes]
        # One exception: if the bounds are the plotting default `(0., 1.)`,
        # overwrite them.
        xspan = [t if t != (0., 1.) else (None, None) for t in xspan]
        yspan = [t if t != (0., 1.) else (None, None) for t in yspan]

    # Set up bounds for plotting.
    for i in range(4):
        if xspan[i][0] is None:
            xmin = None
        else:
            xmin = min(0., xspan[i][0])
        if xspan[i][1] is None:
            xmax = -min(logvol)
        else:
            xmax = max(-min(logvol), xspan[i][1])
        if yspan[i][0] is None:
            ymin = None
        else:
            ymin = min(span[i][0], yspan[i][0])
        if yspan[i][1] is None:
            ymax = span[i][1]
        else:
            ymax = max(span[i][1], yspan[i][1])
        axes[i].set_xlim([xmin, xmax])
        axes[i].set_ylim([ymin, ymax])

    # Plotting.
    labels = ['Live Points', 'Likelihood\n(normalized)',
              'Importance\nWeight', 'Evidence']
    if kde:
        labels[2] += ' PDF'

    for i, d in enumerate(data):

        # Establish axes.
        ax = axes[i]
        # Set color(s)/colormap(s).
        if isinstance(color, str_type):
            c = color
        else:
            c = color[i]
        # Setup axes.
        if max_x_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_x_ticks))
        if max_y_ticks == 0:
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.yaxis.set_major_locator(MaxNLocator(max_y_ticks))
        # Label axes.
        sf = ScalarFormatter(useMathText=use_math_text)
        ax.yaxis.set_major_formatter(sf)
        ax.set_xlabel(r"$-\ln X$", **label_kwargs)
        ax.set_ylabel(labels[i], **label_kwargs)
        # Plot run.
        if kde and i == 2:
            # the KDE was evaluated on its own resampled log-volume grid
            ax.plot(-logvol_new, d, color=c, **plot_kwargs)
        else:
            ax.plot(-logvol, d, color=c, **plot_kwargs)
        if i == 3 and lnz_error:
            if logplot:
                # skip error bands far below the visible y range
                mask = logz >= ax.get_ylim()[0] - 10
                for s in range(1, 4):
                    ax.fill_between(-logvol[mask], (logz + s * logzerr)[mask],
                                    (logz - s * logzerr)[mask],
                                    color=c, alpha=0.2)
            else:
                for s in range(1, 4):
                    ax.fill_between(-logvol, np.exp(logz + s * logzerr),
                                    np.exp(logz - s * logzerr), color=c, alpha=0.2)
        # Mark addition of final live points.
        if mark_final_live:
            # merge into one dict: passing ``lw``/``ls`` alongside the
            # ``linewidth`` already in plot_kwargs would be an alias clash
            final_kwargs = dict(plot_kwargs, linestyle='dashed', linewidth=2)
            ax.axvline(-logvol[live_idx], color=c, **final_kwargs)
            if i == 0:
                ax.axhline(live_idx, color=c, **final_kwargs)
        # Add truth value(s).
        if i == 3 and lnz_truth is not None:
            if logplot:
                ax.axhline(lnz_truth, color=truth_color, **truth_kwargs)
            else:
                ax.axhline(np.exp(lnz_truth), color=truth_color, **truth_kwargs)

    return fig, axes
576

577

578
def _color_for_dim(color, i):
    """Pick the colour used for parameter ``i``.

    ``color`` is either a single matplotlib colour (name, hex string,
    RGB(A) tuple, ...), shared by all parameters, or a sequence holding one
    colour per parameter, in which case its ``i``-th entry is returned.
    """
    from matplotlib.colors import is_color_like
    if is_color_like(color):
        return color
    return color[i]


def traceplot(results, span=None, quantiles=[0.025, 0.5, 0.975], smooth=0.02,
              post_color='blue', post_kwargs=None, kde=True, nkde=1000,
              trace_cmap='plasma', trace_color=None, trace_kwargs=None,
              connect=False, connect_highlight=10, connect_color='red',
              connect_kwargs=None, max_n_ticks=5, use_math_text=False,
              labels=None, label_kwargs=None,
              show_titles=False, title_fmt=".2f", title_kwargs=None,
              truths=None, truth_color='red', truth_kwargs=None,
              verbose=False, fig=None):
    """Plot traces and marginalized posteriors for each parameter.

    Parameters
    ----------
    results : dict-like
        Nested sampling results. The keys read are ``'samples'`` (array of
        shape (nsamps, ndim), or (nsamps,) for a single parameter),
        ``'logvol'`` and ``'weights'`` (1-D, one entry per sample) and,
        only when `connect=True`, ``'samples_id'`` (one particle ID per
        sample).
    span : iterable with shape (ndim,), optional
        A list where each element is either a length-2 tuple containing
        lower and upper bounds or a float from `(0., 1.]` giving the
        fraction of (weighted) samples to include. If a fraction is provided,
        the bounds are chosen to be equal-tailed. An example would be::

            span = [(0., 10.), 0.95, (5., 6.)]

        Default is `0.999999426697` (5-sigma credible interval) for each
        parameter.
    quantiles : iterable, optional
        A list of fractional quantiles to overplot on the 1-D marginalized
        posteriors as vertical dashed lines. Default is `[0.025, 0.5, 0.975]`
        (the 95%/2-sigma credible interval). `None` or an empty list
        disables them.
    smooth : float or iterable with shape (ndim,), optional
        The standard deviation (either a single value or a different value for
        each subplot) for the Gaussian kernel used to smooth the 1-D
        marginalized posteriors, expressed as a fraction of the span.
        Default is `0.02` (2% smoothing). If an integer is provided instead,
        a simple (weighted) histogram with `bins=smooth` is drawn. Float
        values only take effect while `kde` is active; otherwise a 40-bin
        weighted histogram is drawn.
    post_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting the histograms.
        Default is `'blue'`.
    post_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the
        marginalized 1-D posteriors. ``alpha`` defaults to 0.6.
    kde : bool, optional
        Whether to use kernel density estimation. When active, trace points
        are colored by a KDE of the weight distribution along -ln(volume)
        (evaluated on `nkde` grid points and interpolated) instead of by
        the raw weights, and float `smooth` values give Gaussian-smoothed
        histograms. It is switched off automatically unless more than 10
        samples satisfy ``weights * nsamps > 0.1``. Default is `True`.
    nkde : int, optional
        The number of grid points used when evaluating the kernel density
        estimate. Default is `1000`.
    trace_cmap : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style colormap (either a single colormap or a
        different colormap for each subplot) used when plotting the traces,
        where each point is colored according to its weight. Default is
        `'plasma'`.
    trace_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a
        different color for each subplot) used when plotting the traces.
        This overrides the `trace_cmap` option by giving all points
        the same color. Default is `None` (not used).
    trace_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the traces.
        ``s`` (marker size) defaults to 3.
    connect : bool, optional
        Whether to draw lines connecting the paths of unique particles.
        Default is `False`.
    connect_highlight : int or iterable, optional
        If `connect=True`, highlights the paths of a specific set of
        particles. If an integer is passed, that many randomly chosen
        particle paths (at most the number of distinct particles) are
        highlighted. If an iterable is passed, the particle paths
        corresponding to the provided IDs are highlighted.
    connect_color : str, optional
        The color of the highlighted particle paths. Default is `'red'`.
    connect_kwargs : dict, optional
        Extra keyword arguments used for plotting particle paths.
        ``alpha`` defaults to 0.7.
    max_n_ticks : int, optional
        Maximum number of ticks allowed. Default is `5`.
    use_math_text : bool, optional
        Whether the axis tick labels for very large/small exponents should be
        displayed as powers of 10 rather than using `e`. Default is `False`.
    labels : iterable with shape (ndim,), optional
        A list of names for each parameter. If not provided, the default name
        used when plotting will follow :math:`x_i` style.
    label_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_xlabel` and
        `~matplotlib.axes.Axes.set_ylabel` methods.
    show_titles : bool, optional
        Whether to display a title above each 1-D marginalized posterior
        showing the 0.5 quantile along with the upper/lower bounds associated
        with the 0.025 and 0.975 (95%/2-sigma credible interval) quantiles,
        independently of `quantiles`. Default is `False`.
    title_fmt : str, optional
        The format string for the quantiles provided in the title. Default is
        `'.2f'`. `None` suppresses the titles.
    title_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_title` command.
    truths : iterable with shape (ndim,), optional
        A list of reference values that will be overplotted on the traces and
        marginalized 1-D posteriors as solid horizontal/vertical lines.
        Each entry may be a single value or a list of values. Individual
        entries can be exempt using `None`. Default is `None`.
    truth_color : str or iterable with shape (ndim,), optional
        A `~matplotlib`-style color (either a single color or a different
        value for each subplot) used when plotting `truths`.
        Default is `'red'`.
    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the vertical
        and horizontal lines with `truths`. Defaults to a solid line of
        width 2.
    verbose : bool, optional
        Whether to print the values of the computed quantiles associated with
        each parameter. Default is `False`.
    fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional
        If provided, overplot the traces and marginalized 1-D posteriors
        onto the provided figure. The axes must hold ``2 * ndim`` entries.
        Otherwise, by default an internal figure is generated.

    Returns
    -------
    traceplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`)
        Output trace plot. `axes` has shape (ndim, 2), with traces in
        column 0 and posteriors in column 1. For a single parameter it has
        shape (2,), with the trace in ``axes[1]`` and the posterior in
        ``axes[0]``.

    Raises
    ------
    ValueError
        If weights, ln(volume)s, `span` or the provided axes do not match
        the samples, or if `connect=True` but no sample IDs are available.
    AssertionError
        If the samples are not 1-/2-D or have more dimensions than samples.

    """
    # Work on copies so the defaults filled in below never leak back into
    # dictionaries owned by the caller.
    title_kwargs = dict(title_kwargs or {})
    label_kwargs = dict(label_kwargs or {})
    trace_kwargs = dict(trace_kwargs or {})
    connect_kwargs = dict(connect_kwargs or {})
    post_kwargs = dict(post_kwargs or {})
    truth_kwargs = dict(truth_kwargs or {})

    # Set defaults.
    connect_kwargs.setdefault('alpha', 0.7)
    post_kwargs.setdefault('alpha', 0.6)
    trace_kwargs.setdefault('s', 3)
    truth_kwargs.setdefault('linestyle', 'solid')
    truth_kwargs.setdefault('linewidth', 2)

    # Extract weighted samples.
    samples = np.atleast_1d(results['samples'])
    logvol = np.asarray(results['logvol'])
    weights = np.asarray(results['weights'])

    # Bring samples into (ndim, nsamps) layout; 1-D input is one parameter.
    if samples.ndim == 1:
        samples = np.atleast_2d(samples)
    else:
        assert samples.ndim == 2, "Samples must be 1- or 2-D."
        samples = samples.T
    assert samples.shape[0] <= samples.shape[1], "There are more dimensions than samples!"
    ndim, nsamps = samples.shape

    # Validate weights and ln(volume) before they are used by the KDE below,
    # so that mismatches produce a clear error message.
    if weights.ndim != 1:
        raise ValueError("Weights must be 1-D.")
    if nsamps != weights.shape[0]:
        raise ValueError("The number of weights and samples disagree!")
    if logvol.ndim != 1:
        raise ValueError("Ln(volume)'s must be 1-D.")
    if nsamps != logvol.shape[0]:
        raise ValueError("The number of ln(volume)'s and samples disagree!")

    # Trace point colors: raw weights, or a KDE of them along -ln(volume).
    wts = weights
    kde = kde and (weights * nsamps > 0.1).sum() > 10
    if kde:
        try:
            from scipy.ndimage import gaussian_filter as norm_kde
            from scipy.stats import gaussian_kde

            # Derive kernel density estimate.
            wt_kde = gaussian_kde(resample_equal(-logvol, weights))  # KDE
            logvol_grid = np.linspace(logvol[0], logvol[-1], nkde)  # resample
            wt_grid = wt_kde.pdf(-logvol_grid)  # evaluate KDE PDF
            wts = np.interp(-logvol, -logvol_grid, wt_grid)  # interpolate
        except ImportError:
            kde = False

    # Sample IDs and the particle paths to highlight.
    if connect:
        try:
            samples_id = np.asarray(results['samples_id'])
            uid = np.unique(samples_id)
        except Exception:
            raise ValueError("Sample IDs are not defined!")
        if np.ndim(connect_highlight) == 0:
            # A count: pick that many distinct particles at random, capped so
            # sampling without replacement cannot fail.
            nhighlight = min(int(connect_highlight), len(uid))
            ids = np.random.choice(uid, size=nhighlight, replace=False)
        else:
            ids = connect_highlight

    # Determine plotting bounds for marginalized 1-D posteriors.
    span = [0.999999426697] * ndim if span is None else list(span)
    if len(span) != ndim:
        raise ValueError("Dimension mismatch between samples and span.")
    for i, bound in enumerate(span):
        if np.ndim(bound) == 0:
            # A credible fraction: convert it to equal-tailed bounds.
            q = [0.5 - 0.5 * bound, 0.5 + 0.5 * bound]
            span[i] = _quantile(samples[i], q, weights=weights)

    # Setting up labels.
    if labels is None:
        labels = [r"$x_{%d}$" % (i + 1) for i in range(ndim)]

    # Setting up smoothing (any scalar, including numpy scalars, is shared).
    if np.ndim(smooth) == 0:
        smooth = [smooth] * ndim

    # Setting up default plot layout.
    if fig is None:
        fig, axes = pl.subplots(ndim, 2, figsize=(12, 3 * ndim))
    else:
        fig, axes = fig
        # Normalise to the layout pl.subplots produces: (2,) for a single
        # parameter, (ndim, 2) otherwise, so the indexing below always works.
        try:
            axes = np.asarray(axes).reshape((2,) if ndim == 1 else (ndim, 2))
        except ValueError:
            raise ValueError("Provided axes do not match the required shape "
                             "for plotting samples.")

    # Plotting.
    for i, x in enumerate(samples):
        if ndim == 1:
            # Single parameter: trace on the right, posterior on the left.
            trace_ax, post_ax = axes[1], axes[0]
        else:
            trace_ax, post_ax = axes[i, 0], axes[i, 1]
        has_truth = truths is not None and truths[i] is not None
        if has_truth:
            tcolor = _color_for_dim(truth_color, i)

        # --- Trace: parameter value against -ln(volume). ---
        ax = trace_ax
        ax.set_xlim([0., -np.min(logvol)])
        ax.set_ylim([np.min(x), np.max(x)])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks))
            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks))
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))
        ax.set_xlabel(r"$-\ln X$", **label_kwargs)
        ax.set_ylabel(labels[i], **label_kwargs)
        if trace_color is None:
            # Color each point by its weight through the colormap; a single
            # name or Colormap object is shared, a sequence is per-parameter.
            if isinstance(trace_cmap, str_type) or callable(trace_cmap):
                cmap = trace_cmap
            else:
                cmap = trace_cmap[i]
            ax.scatter(-logvol, x, c=wts, cmap=cmap, **trace_kwargs)
        else:
            # Fixed color: no colormap is passed, since matplotlib would
            # ignore it and warn.
            ax.scatter(-logvol, x, c=_color_for_dim(trace_color, i),
                       **trace_kwargs)
        if connect:
            # Add lines highlighting specific particle paths.
            for j in ids:
                sel = samples_id == j
                ax.plot(-logvol[sel], x[sel], color=connect_color,
                        **connect_kwargs)
        if has_truth:
            for t in np.atleast_1d(truths[i]):
                ax.axhline(t, color=tcolor, **truth_kwargs)

        # --- Marginalized 1-D posterior. ---
        ax = post_ax
        color = _color_for_dim(post_color, i)
        ax.set_xlim(span[i])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks))
        # The y axis (weighted histogram height) never gets ticks.
        ax.yaxis.set_major_locator(NullLocator())
        ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))
        ax.set_xlabel(labels[i], **label_kwargs)

        s = smooth[i]
        hist_range = np.sort(span[i])
        if isinstance(s, (int_type, np.integer)):
            # Integer: plain weighted histogram with `s` bins in the span.
            y0, _, _ = ax.hist(x, bins=s, weights=weights, color=color,
                               range=hist_range, **post_kwargs)
        else:
            if kde:
                # Oversample by 10x relative to the smoothing width, then
                # blur with a Gaussian filter 10 bins wide.
                n, b = np.histogram(x, bins=int(round(10. / s)),
                                    weights=weights, range=hist_range)
                y0 = norm_kde(n, 10.)
            else:
                y0, b = np.histogram(x, bins=40, weights=weights,
                                     range=hist_range)
            x0 = 0.5 * (b[1:] + b[:-1])
            ax.fill_between(x0, y0, color=color, **post_kwargs)
        ymax = np.max(y0)
        if ymax > 0:
            # Skip when no weight falls in the span: a (0, 0) range is singular.
            ax.set_ylim([0., ymax * 1.05])

        # Plot quantiles.
        if quantiles is not None and len(quantiles) > 0:
            qs = _quantile(x, quantiles, weights=weights)
            for q in qs:
                ax.axvline(q, lw=2, ls="dashed", color=color)
            if verbose:
                print("Quantiles:")
                print(labels[i], list(zip(quantiles, qs)))
        if has_truth:
            for t in np.atleast_1d(truths[i]):
                ax.axvline(t, color=tcolor, **truth_kwargs)
        if show_titles and title_fmt is not None:
            # Titles always use the 2.5/50/97.5% quantiles.
            ql, qm, qh = _quantile(x, [0.025, 0.5, 0.975], weights=weights)
            fmt = "{{0:{0}}}".format(title_fmt).format
            title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$".format(
                fmt(qm), fmt(qm - ql), fmt(qh - qm))
            ax.set_title("{0} = {1}".format(labels[i], title), **title_kwargs)

    return fig, axes
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc