
JohannesBuchner / UltraNest — coverage build 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC — coverage: 74.53% (+0.3%) from 74.242%
Push build on circleci (committed via web-flow):
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler
(vectorised slice sampler of fixed batch size)

1329 of 2026 branches covered (65.6%); branch coverage is included in the aggregate %.
79 of 80 new or added lines in 1 file covered (98.75%).
1 existing line in 1 file is now uncovered.
4026 of 5159 relevant lines covered (78.04%), 0.78 hits per line.

Source file: /ultranest/integrator.py — 77.73% covered
"""
Nested sampling integrators
---------------------------

This module provides the high-level class :py:class:`ReactiveNestedSampler`,
for calculating the Bayesian evidence and posterior samples of arbitrary models.

"""

# Some parts are from the Nestle library by Kyle Barbary (https://github.com/kbarbary/nestle)
# Some parts are from the nnest library by Adam Moss (https://github.com/adammoss/nnest)

from __future__ import print_function, division

import os
import sys
import csv
import json
import operator
import time
import warnings

from numpy import log, exp, logaddexp
import numpy as np

from .utils import create_logger, make_run_dir, resample_equal, vol_prefactor, vectorize, listify as _listify
from .utils import is_affine_transform, normalised_kendall_tau_distance, distributed_work_chunk_size
from ultranest.mlfriends import MLFriends, AffineLayer, LocalAffineLayer, ScalingLayer, find_nearby, WrappingEllipsoid, RobustEllipsoidRegion
from .store import HDF5PointStore, TextPointStore, NullPointStore
from .viz import get_default_viz_callback
from .ordertest import UniformOrderAccumulator
from .netiter import PointPile, SingleCounter, MultiCounter, BreadthFirstIterator, TreeNode, count_tree_between, find_nodes_before, logz_sequence
from .netiter import dump_tree, combine_results
from .hotstart import get_auxiliary_contbox_parameterization


__all__ = ['ReactiveNestedSampler', 'NestedSampler', 'read_file', 'warmstart_from_similar_file']


def _get_cumsum_range(pi, dp):
    """Compute quantile indices from probabilities.

    Parameters
    ------------
    pi: array
        probability of each item.
    dp: float
        Quantile (between 0 and 0.5).

    Returns
    ---------
    index_lo: int
        Index of the item corresponding to quantile ``dp``.
    index_hi: int
        Index of the item corresponding to quantile ``1-dp``.
    """
    ci = pi.cumsum()
    # this builds a conservatively narrow interval
    # find first index where the cumulative is surely above
    ilo, = np.where(ci >= dp)
    ilo = ilo[0] if len(ilo) > 0 else 0
    # find last index where the cumulative is surely below
    ihi, = np.where(ci <= 1. - dp)
    ihi = ihi[-1] if len(ihi) > 0 else -1
    return ilo, ihi
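
# Example (illustrative sketch, not part of the module): with 100 equally
# weighted items, pi = np.ones(100) / 100. and dp = 0.025,
# _get_cumsum_range(pi, 0.025) returns (2, 96), bracketing the central ~95%.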


def _sequentialize_width_sequence(minimal_widths, min_width):
    """Turn a list of required tree widths into an ordered sequence.

    Parameters
    ------------
    minimal_widths: list of (Llo, Lhi, width)
        Defines the required width between Llo and Lhi.
    min_width: int
        Minimum width everywhere.

    Returns
    ---------
    Lsequence: list of (L, width)
        A sequence of L points and the expected tree width at and above it.

    """
    Lpoints = np.unique(_listify(
        [-np.inf], [L for L, _, _ in minimal_widths],
        [L for _, L, _ in minimal_widths], [np.inf]))
    widths = np.ones(len(Lpoints)) * min_width

    for Llo, Lhi, width in minimal_widths:
        # all Lpoints within that range should be maximized to width
        # mask = np.logical_and(Lpoints >= Llo, Lpoints <= Lhi)
        # the following allows segments to specify -inf..L ranges
        mask = ~np.logical_or(Lpoints < Llo, Lpoints > Lhi)
        widths[mask] = np.where(widths[mask] < width, width, widths[mask])

    # the width has to monotonically increase to the maximum from both sides
    # so we fill up any intermediate dips
    max_width = widths.max()
    mid = np.where(widths == max_width)[0][0]
    widest = 0
    for i in range(mid):
        widest = widths[i] = max(widest, widths[i])
    widest = 0
    for i in range(len(widths) - 1, mid, -1):
        widest = widths[i] = max(widest, widths[i])

    return list(zip(Lpoints, widths))
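
# Example (illustrative sketch, not part of the module): requiring 400 live
# points between L=-10 and L=0 on top of a global minimum of 100 gives
#   _sequentialize_width_sequence([(-10., 0., 400)], 100)
#   -> [(-inf, 100.0), (-10.0, 400.0), (0.0, 400.0), (inf, 100.0)]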


def _explore_iterator_batch(explorer, pop, x_dim, num_params, pointpile, batchsize=1):
    """Walk the stored tree breadth-first and yield batches of replacements.

    For each visited node, a stored point above its likelihood threshold is
    fetched with *pop*, attached as a child via *pointpile*, and collected.
    Yields lists of (Lmin, active_values, children) tuples of length *batchsize*.
    """
    batch = []

    while True:
        next_node = explorer.next_node()
        if next_node is None:
            break
        rootid, node, (_, active_rootids, active_values, active_node_ids) = next_node
        Lmin = node.value
        children = []

        _, row = pop(Lmin)
        if row is not None:
            logl = row[1]
            u = row[3:3 + x_dim]
            v = row[3 + x_dim:3 + x_dim + num_params]

            assert u.shape == (x_dim,)
            assert v.shape == (num_params,)
            assert logl > Lmin
            children.append((u, v, logl))
            child = pointpile.make_node(logl, u, v)
            node.children.append(child)

        batch.append((Lmin, active_values.copy(), children))
        if len(batch) >= batchsize:
            yield batch
            batch = []
        explorer.expand_children_of(rootid, node)
    if len(batch) > 0:
        yield batch

def resume_from_similar_file(
    log_dir, x_dim, loglikelihood, transform,
    max_tau=0, verbose=False, ndraw=400
):
    """
    Change a stored UltraNest run to a modified loglikelihood/transform.

    Parameters
    ----------
    log_dir: str
        Folder containing results
    x_dim: int
        number of dimensions
    loglikelihood: function
        new likelihood function
    transform: function
        new transform function
    verbose: bool
        show progress
    ndraw: int
        set to >1 if functions can take advantage of vectorized computations
    max_tau: float
        Allowed dissimilarity in the live point ordering, quantified as
        normalised Kendall tau distance.

        max_tau=0 is the very conservative choice of stopping the warm start
        when the live point order differs.
        Values near 1 correspond to completely different live point orderings.
        Values in between permit mild disorder.

    Returns
    ----------
    sequence: dict
        contains arrays storing for each iteration estimates of:

            * logz: log evidence estimate
            * logzerr: log evidence uncertainty estimate
            * logvol: log volume estimate
            * samples_n: number of live points
            * logwt: log weight
            * logl: log likelihood

    final: dict
        same as ReactiveNestedSampler.results and
        ReactiveNestedSampler.run return values

    """
    import h5py
    filepath = os.path.join(log_dir, 'results', 'points.hdf5')
    filepath2 = os.path.join(log_dir, 'results', 'points.hdf5.new')
    fileobj = h5py.File(filepath, 'r')
    _, ncols = fileobj['points'].shape
    num_params = ncols - 3 - x_dim

    points = fileobj['points'][:]
    fileobj.close()
    del fileobj

    pointstore2 = HDF5PointStore(filepath2, ncols, mode='w')
    stack = list(enumerate(points))

    pointpile = PointPile(x_dim, num_params)
    pointpile2 = PointPile(x_dim, num_params)

    def pop(Lmin):
        """Find matching sample from points file."""
        # look forward to see if there is an exact match
        # if we do not use the exact matches
        #   this causes a shift in the loglikelihoods
        for i, (idx, next_row) in enumerate(stack):
            row_Lmin = next_row[0]
            L = next_row[1]
            if row_Lmin <= Lmin and L > Lmin:
                idx, row = stack.pop(i)
                return idx, row
        return None, None

    roots = []
    roots2 = []
    initial_points_u = []
    initial_points_v = []
    initial_points_logl = []
    while True:
        _, row = pop(-np.inf)
        if row is None:
            break
        logl = row[1]
        u = row[3:3 + x_dim]
        v = row[3 + x_dim:3 + x_dim + num_params]
        initial_points_u.append(u)
        initial_points_v.append(v)
        initial_points_logl.append(logl)

    v2 = transform(np.array(initial_points_u, ndmin=2, dtype=float))
    assert np.allclose(v2, initial_points_v), 'transform inconsistent, cannot resume'
    logls_new = loglikelihood(v2)

    for u, v, logl, logl_new in zip(initial_points_u, initial_points_v, initial_points_logl, logls_new):
        roots.append(pointpile.make_node(logl, u, v))
        roots2.append(pointpile2.make_node(logl_new, u, v))
        pointstore2.add(_listify([-np.inf, logl_new, 0.0], u, v), 1)

    batchsize = ndraw
    explorer = BreadthFirstIterator(roots)
    explorer2 = BreadthFirstIterator(roots2)
    main_iterator2 = SingleCounter()
    main_iterator2.Lmax = logls_new.max()
    good_state = True

    indices1, indices2 = np.meshgrid(np.arange(len(logls_new)), np.arange(len(logls_new)))
    last_good_like = -1e300
    last_good_state = 0
    epsilon = 1 + 1e-6
    niter = 0
    for batch in _explore_iterator_batch(explorer, pop, x_dim, num_params, pointpile, batchsize=batchsize):
        assert len(batch) > 0
        batch_u = np.array([u for _, _, children in batch for u, _, _ in children], ndmin=2, dtype=float)
        if batch_u.size > 0:
            assert batch_u.shape[1] == x_dim, batch_u.shape
            batch_v = np.array([v for _, _, children in batch for _, v, _ in children], ndmin=2, dtype=float)
            # print("calling likelihood with %d points" % len(batch_u))
            v2 = transform(batch_u)
            assert batch_v.shape[1] == num_params, batch_v.shape
            assert np.allclose(v2, batch_v), 'transform inconsistent, cannot resume'
            logls_new = loglikelihood(batch_v)
        else:
            # no new points
            logls_new = []

        j = 0
        for Lmin, active_values, children in batch:

            next_node2 = explorer2.next_node()
            rootid2, node2, (active_nodes2, _, active_values2, _) = next_node2
            Lmin2 = float(node2.value)

            # in the tails of distributions it can happen that two points are out of order
            # but that may not be very important
            # in the interest of practicality, we allow this and only stop the
            # warmstart copying when some bulk of points differ.
            # in any case, warmstart should not be considered safe, but help iterating
            # and a final clean run is needed to finalise the results.
            if len(active_values) != len(active_values2):
                if verbose == 2:
                    print("stopping, number of live points differ (%d vs %d)" % (len(active_values), len(active_values2)))
                    good_state = False
                break

            if len(active_values) != len(indices1):
                indices1, indices2 = np.meshgrid(np.arange(len(active_values)), np.arange(len(active_values2)))
            tau = normalised_kendall_tau_distance(active_values, active_values2, indices1, indices2)
            order_consistent = tau <= max_tau
            if order_consistent and len(active_values) > 10 and len(active_values2) > 10:
                good_state = True
            elif not order_consistent:
                good_state = False
            else:
                # maintain state
                pass
            if verbose == 2:
                print(niter, tau)
            if good_state:
                # print("        (%.1e)   L=%.1f" % (last_good_like, Lmin2))
                # assert last_good_like < Lmin2, (last_good_like, Lmin2)
                last_good_like = Lmin2
                last_good_state = niter
            else:
                # interpolate an increasing likelihood
                # in the hope that the step size is smaller than
                # the likelihood increase
                Lmin2 = last_good_like
                node2.value = Lmin2
                last_good_like = last_good_like * epsilon
                break

            for u, v, logl_old in children:
                logl_new = logls_new[j]
                j += 1

                # print(j, Lmin2, '->', logl_new, 'instead of', Lmin, '->', [c.value for c in node2.children])
                child2 = pointpile2.make_node(logl_new, u, v)
                node2.children.append(child2)
                if logl_new > Lmin2:
                    pointstore2.add(_listify([Lmin2, logl_new, 0.0], u, v), 1)
                else:
                    if verbose == 2:
                        print("cannot use new point because it would decrease likelihood (%.1f->%.1f)" % (Lmin2, logl_new))
                    # good_state = False
                    # break

            main_iterator2.passing_node(node2, active_nodes2)

            niter += 1
            if verbose:
                sys.stderr.write("%d...\r" % niter)

            explorer2.expand_children_of(rootid2, node2)

        if not good_state:
            break
        if main_iterator2.logZremain < main_iterator2.logZ and not good_state:
            # stop as the results diverged already
            break

    if verbose:
        sys.stderr.write("%d/%d iterations salvaged (%.2f%%).\n" % (
            last_good_state + 1, len(points), (last_good_state + 1) * 100. / len(points)))
    # delete the ones at the end from last_good_state onwards
    # assert len(pointstore2.fileobj['points']) == niter, (len(pointstore2.fileobj['points']), niter)
    mask = pointstore2.fileobj['points'][:,0] <= last_good_like
    points2 = pointstore2.fileobj['points'][:][mask,:]
    del pointstore2.fileobj['points']
    pointstore2.fileobj.create_dataset(
        'points', dtype=np.float64,
        shape=(0, pointstore2.ncols), maxshape=(None, pointstore2.ncols))
    pointstore2.fileobj['points'].resize(len(points2), axis=0)
    pointstore2.fileobj['points'][:] = points2
    pointstore2.close()
    del pointstore2

    os.replace(filepath2, filepath)
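
# Example (illustrative sketch, not part of the module): salvage an existing
# run in 'logs/model1' after the likelihood was modified; the callables here
# are hypothetical vectorized functions as described in the docstring above.
#
#     resume_from_similar_file(
#         'logs/model1', x_dim=2,
#         loglikelihood=new_loglike, transform=new_transform,
#         max_tau=0.1, verbose=True, ndraw=400)
#
# ReactiveNestedSampler calls this internally when resume='resume-similar'.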


def _update_region_bootstrap(region, nbootstraps, minvol=0., comm=None, mpi_size=1):
    """
    Update *region* with *nbootstraps* rounds of excluding points randomly.

    Stiffen ellipsoid size using the minimum volume *minvol*.

    If the mpi communicator *comm* is not None, use MPI to distribute
    the bootstraps over the *mpi_size* processes.
    """
    assert nbootstraps > 0, nbootstraps
    # catch potential errors so MPI syncing still works
    e = None
    try:
        r, f = region.compute_enlargement(
            minvol=minvol,
            nbootstraps=max(1, nbootstraps // mpi_size))
    except np.linalg.LinAlgError as e1:
        e = e1
        r, f = np.nan, np.nan

    if comm is not None:
        recv_maxradii = comm.gather(r, root=0)
        recv_maxradii = comm.bcast(recv_maxradii, root=0)
        # if there are very many processors, we may have more
        # rounds than requested, leading to slowdown
        # thus we throw away the extra ones
        r = np.max(recv_maxradii[:nbootstraps])
        recv_enlarge = comm.gather(f, root=0)
        recv_enlarge = comm.bcast(recv_enlarge, root=0)
        f = np.max(recv_enlarge[:nbootstraps])

    if not np.isfinite(r) and not np.isfinite(f):
        # reraise error if needed
        if e is None:
            raise np.linalg.LinAlgError("compute_enlargement failed")
        else:
            raise e

    region.maxradiussq = r
    region.enlarge = f
    return r, f
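
# Example (illustrative sketch, not part of the module): bootstrap the radius
# and enlargement of a fresh MLFriends region without MPI, assuming 2d points.
#
#     active_u = np.random.uniform(size=(400, 2))
#     layer = AffineLayer(wrapped_dims=[])
#     layer.optimize(active_u, active_u)
#     region = MLFriends(active_u, layer)
#     r, f = _update_region_bootstrap(region, nbootstraps=30)
#     region.create_ellipsoid(minvol=0.)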


class NestedSampler(object):
    """Simple Nested sampler for reference."""

    def __init__(self,
                 param_names,
                 loglike,
                 transform=None,
                 derived_param_names=[],
                 resume='subfolder',
                 run_num=None,
                 log_dir='logs/test',
                 num_live_points=1000,
                 vectorized=False,
                 wrapped_params=[],
                 ):
        """Set up nested sampler.

        Parameters
        -----------
        param_names: list of str, names of the parameters.
            Length gives dimensionality of the sampling problem.
        loglike: function
            log-likelihood function.
            Receives multiple parameter vectors, returns vector of likelihood.
        transform: function
            parameter transform from unit cube to physical parameters.
            Receives multiple cube vectors, returns multiple parameter vectors.
        derived_param_names: list of str
            Additional derived parameters created by transform. (empty by default)
        log_dir: str
            where to store output files
        resume: 'resume', 'overwrite' or 'subfolder'
            if 'overwrite', overwrite previous data.
            if 'subfolder', create a fresh subdirectory in log_dir.
            if 'resume' or True, continue previous run if available.
        wrapped_params: list of bools
            indicating whether this parameter wraps around (circular parameter).
        num_live_points: int
            Number of live points
        vectorized: bool
            If true, loglike and transform function can receive arrays
            of points.
        run_num: int
            unique run number. If None, will be automatically incremented.

        """
        self.paramnames = param_names
        x_dim = len(self.paramnames)
        self.num_live_points = num_live_points
        self.sampler = 'nested'
        self.x_dim = x_dim
        self.derivedparamnames = derived_param_names
        num_derived = len(self.derivedparamnames)
        self.num_params = x_dim + num_derived
        self.volfactor = vol_prefactor(self.x_dim)
        if wrapped_params is None:
            self.wrapped_axes = []
        else:
            self.wrapped_axes = np.where(wrapped_params)[0]

        assert resume or resume in ('overwrite', 'subfolder', 'resume'), "resume should be one of 'overwrite' 'subfolder' or 'resume'"
        append_run_num = resume == 'subfolder'
        resume = resume == 'resume' or resume

        if not vectorized:
            transform = vectorize(transform)
            loglike = vectorize(loglike)

        if transform is None:
            self.transform = lambda x: x
        else:
            self.transform = transform

        u = np.random.uniform(size=(2, self.x_dim))
        p = self.transform(u)
        assert p.shape == (2, self.num_params), ("Error in transform function: returned shape is %s, expected %s" % (p.shape, (2, self.num_params)))
        logl = loglike(p)
        assert np.logical_and(u > 0, u < 1).all(), ("Error in transform function: u was modified!")
        assert np.shape(logl) == (2,), ("Error in loglikelihood function: returned shape is %s, expected %s" % (np.shape(logl), (2,)))
        assert np.isfinite(logl).all(), ("Error in loglikelihood function: returned non-finite number: %s for input u=%s p=%s" % (logl, u, p))

        def safe_loglike(x):
            """Call likelihood function safely wrapped to avoid non-finite values."""
            x = np.asarray(x)
            logl = loglike(x)
            assert np.isfinite(logl).all(), (
                'User-provided loglikelihood returned non-finite value:',
                logl[~np.isfinite(logl)][0],
                "for input value:",
                x[~np.isfinite(logl),:][0,:])
            return logl

        self.loglike = safe_loglike

        self.use_mpi = False
        try:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.mpi_size = self.comm.Get_size()
            self.mpi_rank = self.comm.Get_rank()
            if self.mpi_size > 1:
                self.use_mpi = True
        except Exception:
            self.mpi_size = 1
            self.mpi_rank = 0

        self.log = self.mpi_rank == 0
        self.log_to_disk = self.log and log_dir is not None

        if self.log and log_dir is not None:
            self.logs = make_run_dir(log_dir, run_num, append_run_num=append_run_num)
            log_dir = self.logs['run_dir']
        else:
            log_dir = None

        self.logger = create_logger(__name__ + '.' + type(self).__name__, log_dir=log_dir)

        if self.log:
            self.logger.info('Num live points [%d]', self.num_live_points)

        if self.log_to_disk:
            # self.pointstore = TextPointStore(os.path.join(self.logs['results'], 'points.tsv'), 2 + self.x_dim + self.num_params)
            self.pointstore = HDF5PointStore(
                os.path.join(self.logs['results'], 'points.hdf5'),
                3 + self.x_dim + self.num_params, mode='a' if resume else 'w')
        else:
            self.pointstore = NullPointStore(3 + self.x_dim + self.num_params)

    def run(
            self,
            update_interval_iter=None,
            update_interval_ncall=None,
            log_interval=None,
            dlogz=0.001,
            max_iters=None):
        """Explore parameter space.

        Parameters
        ----------
        update_interval_iter:
            Update region after this many iterations.
        update_interval_ncall:
            Update region after update_interval_ncall likelihood calls.
        log_interval:
            Update stdout status line every log_interval iterations
        dlogz:
            Target evidence uncertainty.
        max_iters:
            maximum number of integration iterations.

        """
        if update_interval_ncall is None:
            update_interval_ncall = max(1, round(self.num_live_points))

        if update_interval_iter is None:
            if update_interval_ncall == 0:
                update_interval_iter = max(1, round(self.num_live_points))
            else:
                update_interval_iter = max(1, round(0.2 * self.num_live_points))

        if log_interval is None:
            log_interval = max(1, round(0.2 * self.num_live_points))
        else:
            log_interval = round(log_interval)
            if log_interval < 1:
                raise ValueError("log_interval must be >= 1")

        viz_callback = get_default_viz_callback()

        prev_u = []
        prev_v = []
        prev_logl = []
        if self.log:
            # try to resume:
            self.logger.info('Resuming...')
            for i in range(self.num_live_points):
                _, row = self.pointstore.pop(-np.inf)
                if row is not None:
                    prev_logl.append(row[1])
                    prev_u.append(row[3:3 + self.x_dim])
                    prev_v.append(row[3 + self.x_dim:3 + self.x_dim + self.num_params])
                else:
                    break

            prev_u = np.array(prev_u)
            prev_v = np.array(prev_v)
            prev_logl = np.array(prev_logl)

            num_live_points_missing = self.num_live_points - len(prev_logl)
        else:
            num_live_points_missing = -1

        if self.use_mpi:
            num_live_points_missing = self.comm.bcast(num_live_points_missing, root=0)
            prev_u = self.comm.bcast(prev_u, root=0)
            prev_v = self.comm.bcast(prev_v, root=0)
            prev_logl = self.comm.bcast(prev_logl, root=0)

        use_point_stack = True

        assert num_live_points_missing >= 0
        if num_live_points_missing > 0:
            if self.use_mpi:
                # self.logger.info('Using MPI with rank [%d]', self.mpi_rank)
                if self.mpi_rank == 0:
                    active_u = np.random.uniform(size=(num_live_points_missing, self.x_dim))
                else:
                    active_u = np.empty((num_live_points_missing, self.x_dim), dtype=np.float64)
                active_u = self.comm.bcast(active_u, root=0)
            else:
                active_u = np.random.uniform(size=(num_live_points_missing, self.x_dim))
            active_v = self.transform(active_u)

            if self.use_mpi:
                if self.mpi_rank == 0:
                    chunks = [[] for _ in range(self.mpi_size)]
                    for i, chunk in enumerate(active_v):
                        chunks[i % self.mpi_size].append(chunk)
                else:
                    chunks = None
                data = self.comm.scatter(chunks, root=0)
                active_logl = self.loglike(data)
                recv_active_logl = self.comm.gather(active_logl, root=0)
                recv_active_logl = self.comm.bcast(recv_active_logl, root=0)
                active_logl = np.concatenate(recv_active_logl, axis=0)
            else:
                active_logl = self.loglike(active_v)

            if self.log_to_disk:
                for i in range(num_live_points_missing):
                    self.pointstore.add(
                        _listify([-np.inf, active_logl[i], 0.], active_u[i,:], active_v[i,:]),
                        num_live_points_missing)

            if len(prev_u) > 0:
                active_u = np.concatenate((prev_u, active_u))
                active_v = np.concatenate((prev_v, active_v))
                active_logl = np.concatenate((prev_logl, active_logl))
            assert active_u.shape == (self.num_live_points, self.x_dim)
            assert active_v.shape == (self.num_live_points, self.num_params)
            assert active_logl.shape == (self.num_live_points,)
        else:
            active_u = prev_u
            active_v = prev_v
            active_logl = prev_logl

        saved_u = []
        saved_v = []  # Stored points for posterior results
        saved_logl = []
        saved_logwt = []
        h = 0.0  # Information, initially 0.
        logz = -1e300  # ln(Evidence Z), initially Z=0
        logvol = log(1.0 - exp(-1.0 / self.num_live_points))
        logz_remain = np.max(active_logl)
        fraction_remain = 1.0
        ncall = num_live_points_missing  # number of calls we already made
        first_time = True
        if self.x_dim > 1:
            transformLayer = AffineLayer(wrapped_dims=self.wrapped_axes)
        else:
            transformLayer = ScalingLayer(wrapped_dims=self.wrapped_axes)
        transformLayer.optimize(active_u, active_u)
        region = MLFriends(active_u, transformLayer)

        if self.log:
            self.logger.info('Starting sampling ...')
        ib = 0
        samples = []
        ndraw = 100
        it = 0
        next_update_interval_ncall = -1
        next_update_interval_iter = -1

        while max_iters is None or it < max_iters:

            # Worst object in collection and its weight (= volume * likelihood)
            worst = np.argmin(active_logl)
            logwt = logvol + active_logl[worst]

            # Update evidence Z and information h.
            logz_new = np.logaddexp(logz, logwt)
            h = (exp(logwt - logz_new) * active_logl[worst] + exp(logz - logz_new) * (h + logz) - logz_new)
            logz = logz_new

            # Add worst object to samples.
            saved_u.append(np.array(active_u[worst]))
            saved_v.append(np.array(active_v[worst]))
            saved_logwt.append(logwt)
            saved_logl.append(active_logl[worst])

            # expected_vol = np.exp(-it / self.num_live_points)

            # The new likelihood constraint is that of the worst object.
            loglstar = active_logl[worst]

            if ncall > next_update_interval_ncall and it > next_update_interval_iter:

                if first_time:
                    nextregion = region
                else:
                    # rebuild space
                    # print()
                    # print("rebuilding space...", active_u.shape, active_u)
                    nextTransformLayer = transformLayer.create_new(active_u, region.maxradiussq)
                    nextregion = MLFriends(active_u, nextTransformLayer)

                # print("computing maxradius...")
                r, f = _update_region_bootstrap(nextregion, 30, 0., self.comm if self.use_mpi else None, self.mpi_size)

                nextregion.maxradiussq = r
                nextregion.enlarge = f
                # force shrinkage of volume
                # this is to avoid re-connection of dying out nodes
                if nextregion.estimate_volume() < region.estimate_volume():
                    region = nextregion
                    transformLayer = region.transformLayer
                region.create_ellipsoid(minvol=exp(-it / self.num_live_points) * self.volfactor)

                if self.log:
                    viz_callback(
                        points=dict(u=active_u, p=active_v, logl=active_logl),
                        info=dict(
                            it=it, ncall=ncall, logz=logz, logz_remain=logz_remain,
                            paramnames=self.paramnames + self.derivedparamnames,
                            logvol=logvol),
                        region=region, transformLayer=transformLayer)
                    self.pointstore.flush()

                next_update_interval_ncall = ncall + update_interval_ncall
                next_update_interval_iter = it + update_interval_iter
                first_time = False

            while True:
                if ib >= len(samples) and use_point_stack:
                    # root checks the point store
                    next_point = np.zeros((1, 3 + self.x_dim + self.num_params))

                    if self.log_to_disk:
                        _, stored_point = self.pointstore.pop(loglstar)
                        if stored_point is not None:
                            next_point[0,:] = stored_point
                        else:
                            next_point[0,:] = -np.inf
                        use_point_stack = not self.pointstore.stack_empty

                    if self.use_mpi:  # and informs everyone
                        use_point_stack = self.comm.bcast(use_point_stack, root=0)
                        next_point = self.comm.bcast(next_point, root=0)

                    # assert not use_point_stack

                    # unpack
                    likes = next_point[:,1]
                    samples = next_point[:,3:3 + self.x_dim]
                    samplesv = next_point[:,3 + self.x_dim:3 + self.x_dim + self.num_params]
                    # skip if we already know it is not useful
                    ib = 0 if np.isfinite(likes[0]) else 1

                while ib >= len(samples):
                    # get new samples
                    ib = 0

                    nc = 0
                    u = region.sample(nsamples=ndraw)
                    nu = u.shape[0]
                    if nu == 0:
                        v = np.empty((0, self.x_dim))
                        logl = np.empty((0,))
                    else:
                        v = self.transform(u)
                        logl = self.loglike(v)
                        nc += nu
                        accepted = logl > loglstar
                        u = u[accepted,:]
                        v = v[accepted,:]
                        logl = logl[accepted]
                        # father = father[accepted]

                    # collect results from all MPI members
                    if self.use_mpi:
                        recv_samples = self.comm.gather(u, root=0)
                        recv_samplesv = self.comm.gather(v, root=0)
                        recv_likes = self.comm.gather(logl, root=0)
                        recv_nc = self.comm.gather(nc, root=0)
                        recv_samples = self.comm.bcast(recv_samples, root=0)
                        recv_samplesv = self.comm.bcast(recv_samplesv, root=0)
                        recv_likes = self.comm.bcast(recv_likes, root=0)
                        recv_nc = self.comm.bcast(recv_nc, root=0)
                        samples = np.concatenate(recv_samples, axis=0)
                        samplesv = np.concatenate(recv_samplesv, axis=0)
                        likes = np.concatenate(recv_likes, axis=0)
                        ncall += sum(recv_nc)
                    else:
                        samples = np.array(u)
                        samplesv = np.array(v)
                        likes = np.array(logl)
                        ncall += nc

                    if self.log:
                        for ui, vi, logli in zip(samples, samplesv, likes):
                            self.pointstore.add(
                                _listify([loglstar, logli, 0.0], ui, vi),
                                ncall)

                if likes[ib] > loglstar:
                    active_u[worst] = samples[ib, :]
                    active_v[worst] = samplesv[ib,:]
                    active_logl[worst] = likes[ib]

                    # if we keep the region informed about the new live points
                    # then the region follows the live points even if maxradius is not updated
                    region.u[worst,:] = active_u[worst]
                    region.unormed[worst,:] = region.transformLayer.transform(region.u[worst,:])

                    # if we track the cluster assignment, then in the next round
                    # the ids with the same members are likely to have the same id
                    # this is imperfect
                    # transformLayer.clusterids[worst] = transformLayer.clusterids[father[ib]]
                    # so we just mark the replaced ones as "unassigned"
                    transformLayer.clusterids[worst] = 0
                    ib = ib + 1
                    break
                else:
                    ib = ib + 1

            # Shrink interval
            logvol -= 1.0 / self.num_live_points
            logz_remain = np.max(active_logl) - it / self.num_live_points
            fraction_remain = np.logaddexp(logz, logz_remain) - logz

            if it % log_interval == 0 and self.log:
                # nicelogger(self.paramnames, active_u, active_v, active_logl, it, ncall, logz, logz_remain, region=region)
                sys.stdout.write('Z=%.1g+%.1g | Like=%.1g..%.1g | it/evals=%d/%d eff=%.4f%%  \r' % (
                    logz, logz_remain, loglstar, np.max(active_logl), it,
                    ncall, np.inf if ncall == 0 else it * 100 / ncall))
                sys.stdout.flush()

                # if efficiency becomes low, bulk-process larger arrays
                ndraw = max(128, min(16384, round((ncall + 1) / (it + 1) / self.mpi_size)))

            # Stopping criterion
            if fraction_remain < dlogz:
                break
            it = it + 1

        logvol = -len(saved_v) / self.num_live_points - log(self.num_live_points)
        for i in range(self.num_live_points):
            logwt = logvol + active_logl[i]
            logz_new = np.logaddexp(logz, logwt)
            h = (exp(logwt - logz_new) * active_logl[i] + exp(logz - logz_new) * (h + logz) - logz_new)
            logz = logz_new
            saved_u.append(np.array(active_u[i]))
            saved_v.append(np.array(active_v[i]))
            saved_logwt.append(logwt)
            saved_logl.append(active_logl[i])

        saved_u = np.array(saved_u)
        saved_v = np.array(saved_v)
        saved_wt = exp(np.array(saved_logwt) - logz)
        saved_logl = np.array(saved_logl)
        logzerr = np.sqrt(h / self.num_live_points)

        if self.log_to_disk:
            with open(os.path.join(self.logs['results'], 'final.csv'), 'w') as f:
                writer = csv.writer(f)
                writer.writerow(['niter', 'ncall', 'logz', 'logzerr', 'h'])
                writer.writerow([it + 1, ncall, logz, logzerr, h])
            self.pointstore.close()

        if not self.use_mpi or self.mpi_rank == 0:
            print()
            print("niter: {:d}\n ncall: {:d}\n nsamples: {:d}\n logz: {:6.3f} +/- {:6.3f}\n h: {:6.3f}"
                  .format(it + 1, ncall, len(saved_v), logz, logzerr, h))

        self.results = dict(
            samples=resample_equal(saved_v, saved_wt / saved_wt.sum()),
            ncall=ncall, niter=it, logz=logz, logzerr=logzerr,
            weighted_samples=dict(
                upoints=saved_u, points=saved_v, weights=saved_wt,
                logweights=saved_logwt, logl=saved_logl),
        )

        return self.results

    def print_results(self):
        """Give summary of marginal likelihood and parameters."""
        print()
        print('logZ = %(logz).3f +- %(logzerr).3f' % self.results)

        print()
        for i, p in enumerate(self.paramnames + self.derivedparamnames):
            v = self.results['samples'][:,i]
            sigma = v.std()
            med = v.mean()
            if sigma == 0:
                i = 3
            else:
                i = max(0, int(-np.floor(np.log10(sigma))) + 1)
            fmt = '%%.%df' % i
            fmts = '\t'.join(['    %-20s' + fmt + " +- " + fmt])
            print(fmts % (p, med, sigma))

    def plot(self):
        """Make corner plot."""
        if self.log_to_disk:
            import matplotlib.pyplot as plt
            import corner
            data = np.array(self.results['weighted_samples']['points'])
            weights = np.array(self.results['weighted_samples']['weights'])
            cumsumweights = np.cumsum(weights)

            mask = cumsumweights > 1e-4

            corner.corner(
                data[mask,:], weights=weights[mask],
                labels=self.paramnames + self.derivedparamnames,
                show_titles=True)
            plt.savefig(os.path.join(self.logs['plots'], 'corner.pdf'), bbox_inches='tight')
            plt.close()
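
# Example (illustrative sketch, not part of the module): the reference sampler
# on a toy 2d Gaussian, with vectorized callables as required by the API above.
#
#     def my_transform(cube):
#         return cube * 10 - 5
#
#     def my_loglike(params):
#         return -0.5 * ((params / 0.5) ** 2).sum(axis=1)
#
#     sampler = NestedSampler(['a', 'b'], my_loglike, transform=my_transform,
#                             vectorized=True, num_live_points=400)
#     results = sampler.run(dlogz=0.5)
#     sampler.print_results()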


def warmstart_from_similar_file(
    usample_filename,
    param_names,
    loglike,
    transform,
    vectorized=False,
    min_num_samples=50
):
    """Warmstart from a previous run.

    Usage::

        aux_paramnames, aux_log_likelihood, aux_prior_transform, vectorized = warmstart_from_similar_file(
            'model1/chains/weighted_post_untransformed.txt', parameters, log_likelihood_with_background, prior_transform)

        aux_sampler = ReactiveNestedSampler(aux_paramnames, aux_log_likelihood, transform=aux_prior_transform, vectorized=vectorized)
        aux_results = aux_sampler.run()
        posterior_samples = aux_results['samples'][:,:-1]

    See :py:func:`ultranest.hotstart.get_auxiliary_contbox_parameterization`
    for more information.

    The remaining parameters have the same meaning as in :py:class:`ReactiveNestedSampler`.

    Parameters
    ------------
    usample_filename: str
        'directory/chains/weighted_post_untransformed.txt'
        contains posteriors in u-space (untransformed) of a previous run.
        Columns are weight, logl, param1, param2, ...
    min_num_samples: int
        minimum number of samples in the usample_filename file required.
        Too few samples will give a poor approximation.

    Other Parameters
    -----------------
    param_names: list
    loglike: function
    transform: function
    vectorized: bool

    Returns
    ---------
    aux_param_names: list
        new parameter list
    aux_loglikelihood: function
        new loglikelihood function
    aux_transform: function
        new prior transform function
    vectorized: bool
        whether the new functions are vectorized
    """
    # load samples
    try:
        with open(usample_filename) as f:
            old_param_names = f.readline().lstrip('#').strip().split()
            auxiliary_usamples = np.loadtxt(f)
    except IOError:
        warnings.warn('not hot-resuming, could not load file "%s"' % usample_filename)
        return param_names, loglike, transform, vectorized

    ulogl = auxiliary_usamples[:,1]
    uweights_full = auxiliary_usamples[:,0] * np.exp(ulogl - ulogl.max())
    mask = uweights_full > 0
    uweights = uweights_full[mask]
    uweights /= uweights.sum()
    upoints = auxiliary_usamples[mask,2:]
    del auxiliary_usamples

    nsamples = len(upoints)
    if nsamples < min_num_samples:
        raise ValueError('file "%s" has too few samples (%d) to hot-resume' % (usample_filename, nsamples))

    # check that the parameter meanings have not changed
    if old_param_names != ['weight', 'logl'] + param_names:
        raise ValueError('file "%s" has parameters %s, expected %s, cannot hot-resume.' % (usample_filename, old_param_names, param_names))

    return get_auxiliary_contbox_parameterization(
        param_names, loglike=loglike, transform=transform,
        vectorized=vectorized,
        upoints=upoints,
        uweights=uweights,
    )
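
# Example (illustrative sketch, not part of the module): accelerate a refit
# with a modified likelihood using the posterior file written by a previous
# run (paths and callables here are hypothetical).
#
#     aux_names, aux_loglike, aux_transform, aux_vectorized = warmstart_from_similar_file(
#         'model1/chains/weighted_post_untransformed.txt',
#         param_names, updated_loglike, prior_transform)
#     sampler = ReactiveNestedSampler(aux_names, aux_loglike, transform=aux_transform,
#                                     vectorized=aux_vectorized)
#     results = sampler.run()
#     # drop the auxiliary coordinate appended as the last column
#     posterior = results['samples'][:, :-1]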


class ReactiveNestedSampler(object):
    """Nested sampler with reactive exploration strategy.

    Storage & resume capable, optionally MPI parallelised.
    """

    def __init__(self,
                 param_names,
                 loglike,
                 transform=None,
                 derived_param_names=[],
                 wrapped_params=None,
                 resume='subfolder',
                 run_num=None,
                 log_dir=None,
                 num_test_samples=2,
                 draw_multiple=True,
                 num_bootstraps=30,
                 vectorized=False,
                 ndraw_min=128,
                 ndraw_max=65536,
                 storage_backend='hdf5',
                 warmstart_max_tau=-1,
                 ):
        """Initialise nested sampler.

        Parameters
        -----------
        param_names: list of str, names of the parameters.
            Length gives dimensionality of the sampling problem.

        loglike: function
            log-likelihood function.
            Receives multiple parameter vectors, returns vector of likelihood.
        transform: function
            parameter transform from unit cube to physical parameters.
            Receives multiple cube vectors, returns multiple parameter vectors.

        derived_param_names: list of str
            Additional derived parameters created by transform. (empty by default)

        log_dir: str
            where to store output files
        resume: 'resume', 'resume-similar', 'overwrite' or 'subfolder'

            if 'overwrite', overwrite previous data.

            if 'subfolder', create a fresh subdirectory in log_dir.

            if 'resume' or True, continue previous run if available.
            Only works when dimensionality, transform and likelihood are consistent.

            if 'resume-similar', continue previous run if available.
            Only works when dimensionality and transform are consistent.
            If a likelihood difference is detected, the existing likelihoods
            are updated until the live point order differs.
            Otherwise, behaves like resume.

        run_num: int or None
            If resume=='subfolder', this is the subfolder number.
            Automatically increments if set to None.

        wrapped_params: list of bools
            indicating whether this parameter wraps around (circular parameter).

        num_test_samples: int
            test transform and likelihood with this number of
            random points for errors first. Useful to catch bugs.

        vectorized: bool
            If true, loglike and transform function can receive arrays
            of points.

        draw_multiple: bool
            If efficiency goes down, dynamically draw more points
            from the region between `ndraw_min` and `ndraw_max`.
            If set to False, few points are sampled at once.

        ndraw_min: int
            Minimum number of points to simultaneously propose.
            Increase this if your likelihood makes vectorization very cheap.

        ndraw_max: int
            Maximum number of points to simultaneously propose.
            Increase this if your likelihood makes vectorization very cheap.
            Memory allocation may be slow for extremely high values.

        num_bootstraps: int
            number of logZ estimators and MLFriends region
            bootstrap rounds.

        storage_backend: str or class
            Class to use for storing the evaluated points (see ultranest.store)
            'hdf5' is strongly recommended. 'tsv' and 'csv' are also possible.

        warmstart_max_tau: float
            Maximum disorder to accept when resume='resume-similar';
            Live points are reused as long as the live point order
            is below this normalised Kendall tau distance.
            Values from 0 (highly conservative) to 1 (extremely negligent).
        """
        self.paramnames = param_names
        x_dim = len(self.paramnames)

        self.sampler = 'reactive-nested'
        self.x_dim = x_dim
        self.transform_layer_class = LocalAffineLayer if x_dim > 1 else ScalingLayer
        self.derivedparamnames = derived_param_names
        self.num_bootstraps = int(num_bootstraps)
        num_derived = len(self.derivedparamnames)
        self.num_params = x_dim + num_derived
        if wrapped_params is None:
            self.wrapped_axes = []
        else:
            assert len(wrapped_params) == self.x_dim, ("wrapped_params has the wrong number of entries:", wrapped_params, ", expected", self.x_dim)
            self.wrapped_axes = np.where(wrapped_params)[0]

        self.use_mpi = False
        try:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.mpi_size = self.comm.Get_size()
            self.mpi_rank = self.comm.Get_rank()
            if self.mpi_size > 1:
                self.use_mpi = True
                self._setup_distributed_seeds()
        except Exception:
            self.mpi_size = 1
            self.mpi_rank = 0

        self.log = self.mpi_rank == 0
        self.log_to_disk = self.log and log_dir is not None
        self.log_to_pointstore = self.log_to_disk

        assert resume in (True, 'overwrite', 'subfolder', 'resume', 'resume-similar'), \
            "resume should be one of 'overwrite' 'subfolder', 'resume' or 'resume-similar'"
        append_run_num = resume == 'subfolder'
        resume_similar = resume == 'resume-similar'
        resume = resume in ('resume-similar', 'resume', True)

        if self.log and log_dir is not None:
            self.logs = make_run_dir(log_dir, run_num, append_run_num=append_run_num)
            log_dir = self.logs['run_dir']
        else:
            log_dir = None

        if self.log:
            self.logger = create_logger('ultranest', log_dir=log_dir)
            self.logger.debug('ReactiveNestedSampler: dims=%d+%d, resume=%s, log_dir=%s, backend=%s, vectorized=%s, nbootstraps=%s, ndraw=%s..%s' % (
                x_dim, num_derived, resume, log_dir, storage_backend, vectorized,
                num_bootstraps, ndraw_min, ndraw_max,
            ))
        self.root = TreeNode(id=-1, value=-np.inf)

        self.pointpile = PointPile(self.x_dim, self.num_params)
        if self.log_to_pointstore:
            storage_filename = os.path.join(self.logs['results'], 'points.' + storage_backend)
            storage_num_cols = 3 + self.x_dim + self.num_params
            if storage_backend == 'tsv':
                self.pointstore = TextPointStore(storage_filename, storage_num_cols)
                self.pointstore.delimiter = '\n'
            elif storage_backend == 'csv':
                self.pointstore = TextPointStore(storage_filename, storage_num_cols)
                self.pointstore.delimiter = ','
            elif storage_backend == 'hdf5':
                self.pointstore = HDF5PointStore(storage_filename, storage_num_cols, mode='a' if resume else 'w')
            else:
                # use custom backend
                self.pointstore = storage_backend
        else:
            self.pointstore = NullPointStore(3 + self.x_dim + self.num_params)
        self.ncall = self.pointstore.ncalls
        self.ncall_region = 0

        if not vectorized:
            if transform is not None:
                transform = vectorize(transform)
            loglike = vectorize(loglike)
            draw_multiple = False

        self.draw_multiple = draw_multiple
        self.ndraw_min = ndraw_min
        self.ndraw_max = ndraw_max
        self.build_tregion = transform is not None
        if not self._check_likelihood_function(transform, loglike, num_test_samples):
            assert self.log_to_disk
            if resume_similar and self.log_to_disk:
                assert storage_backend == 'hdf5', 'resume-similar is only supported for HDF5 files'
                assert 0 <= warmstart_max_tau <= 1, 'warmstart_max_tau parameter needs to be set to a value between 0 and 1'
                # close
                self.pointstore.close()
                del self.pointstore
                # rewrite points file
                if self.log:
                    self.logger.info('trying to salvage points from previous, different run ...')
                resume_from_similar_file(
                    log_dir, x_dim, loglike, transform,
                    ndraw=ndraw_min if vectorized else 1,
                    max_tau=warmstart_max_tau, verbose=False)
                self.pointstore = HDF5PointStore(
                    os.path.join(self.logs['results'], 'points.hdf5'),
                    3 + self.x_dim + self.num_params, mode='a' if resume else 'w')
            elif resume:
                raise Exception("Cannot resume because loglikelihood function changed, "
                                "unless resume=resume-similar. To start from scratch, delete '%s'." % (log_dir))
        self._set_likelihood_function(transform, loglike, num_test_samples)
        self.stepsampler = None
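
    # Example (illustrative sketch, not part of the class): typical use with
    # storage and resume support; `min_num_live_points` is an argument of
    # run() (defined further below in this module).
    #
    #     sampler = ReactiveNestedSampler(
    #         ['amplitude', 'slope'], loglike, transform=transform,
    #         log_dir='logs/fit1', resume='resume', vectorized=True)
    #     result = sampler.run(min_num_live_points=400)
    #     sampler.print_results()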
1222

1223
    def _setup_distributed_seeds(self):
1✔
1224
        if not self.use_mpi:
×
1225
            return
×
1226
        seed = 0
×
1227
        if self.mpi_rank == 0:
×
1228
            seed = np.random.randint(0, 1000000)
×
1229

1230
        seed = self.comm.bcast(seed, root=0)
×
1231
        if self.mpi_rank > 0:
×
1232
            # from http://arxiv.org/abs/1005.4117
1233
            seed = int(abs(((seed * 181) * ((self.mpi_rank - 83) * 359)) % 104729))
×
1234
            # print('setting seed:', self.mpi_rank, seed)
1235
            np.random.seed(seed)
×
1236

1237
    def _check_likelihood_function(self, transform, loglike, num_test_samples):
1✔
1238
        """Test the `transform` and `loglike`lihood functions.
1239

1240
        `num_test_samples` samples are used to check whether they work and give the correct output.
1241

1242
        returns whether the most recently stored point (if any)
1243
        still returns the same likelihood value.
1244
        """
1245
        # do some checks on the likelihood function
1246
        # this makes debugging easier by failing early with meaningful errors
1247

1248
        # if we are resuming, check that last sample still gives same result
1249
        num_resume_test_samples = 0
1✔
1250
        if num_test_samples and not self.pointstore.stack_empty:
1✔
1251
            num_resume_test_samples = 1
1✔
1252
            num_test_samples -= 1
1✔
1253

1254
        if num_test_samples > 0:
1!
1255
            # test with num_test_samples random points
1256
            u = np.random.uniform(size=(num_test_samples, self.x_dim))
1✔
1257
            p = transform(u) if transform is not None else u
1✔
1258
            assert np.shape(p) == (num_test_samples, self.num_params), (
1✔
1259
                "Error in transform function: returned shape is %s, expected %s" % (
1260
                    np.shape(p), (num_test_samples, self.num_params)))
1261
            logl = loglike(p)
1✔
1262
            assert np.logical_and(u > 0, u < 1).all(), (
1✔
1263
                "Error in transform function: u was modified!")
1264
            assert np.shape(logl) == (num_test_samples,), (
1✔
1265
                "Error in loglikelihood function: returned shape is %s, expected %s" % (np.shape(logl), (num_test_samples,)))
1266
            assert np.isfinite(logl).all(), (
1✔
1267
                "Error in loglikelihood function: returned non-finite number: %s for input u=%s p=%s" % (logl, u, p))
1268

1269
        if not self.pointstore.stack_empty and num_resume_test_samples > 0:
1✔
1270
            # test that last sample gives the same likelihood value
1271
            _, lastrow = self.pointstore.stack[-1]
1✔
1272
            assert len(lastrow) == 3 + self.x_dim + self.num_params, (
1✔
1273
                "Cannot resume: problem has different dimensionality",
1274
                len(lastrow), (3, self.x_dim, self.num_params))
1275
            lastL = lastrow[1]
1✔
1276
            lastu = lastrow[3:3 + self.x_dim]
1✔
1277
            u = lastu.reshape((1, -1))
1✔
1278
            lastp = lastrow[3 + self.x_dim:3 + self.x_dim + self.num_params]
1✔
1279
            if self.log:
1!
1280
                self.logger.debug("Testing resume consistency: %s: u=%s -> p=%s -> L=%s ", lastrow, lastu, lastp, lastL)
1✔
1281
            p = transform(u) if transform is not None else u
1✔
1282
            if not np.allclose(p.flatten(), lastp) and self.log:
1!
1283
                self.logger.warning(
×
1284
                    "Trying to resume from previous run, but transform function gives different result: %s gave %s, now %s",
1285
                    lastu, lastp, p.flatten())
1286
            assert np.allclose(p.flatten(), lastp), (
1✔
1287
                "Cannot resume because transform function changed. "
1288
                "To start from scratch, delete '%s'." % (self.logs['run_dir']))
1289
            logl = loglike(p).flatten()[0]
1✔
1290
            if not np.isclose(logl, lastL) and self.log:
1✔
1291
                self.logger.warning(
1✔
1292
                    "Trying to resume from previous run, but likelihood function gives different result: %s gave %s, now %s",
1293
                    lastu.flatten(), lastL, logl)
1294
            return np.isclose(logl, lastL)
1✔
1295
        return True
1✔
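    # Minimal sketch of functions that satisfy the contract checked above
    # (vectorized convention, here with num_params == x_dim; names and numbers
    # are purely illustrative):
    #
    #     def my_transform(u):    # u: (n, x_dim) points in the unit cube
    #         return 10 * u - 5   # -> (n, num_params) physical parameters
    #
    #     def my_loglike(p):      # p: (n, num_params)
    #         return -0.5 * ((p / 0.1)**2).sum(axis=1)  # -> (n,) finite floats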
1296

1297
    def _set_likelihood_function(self, transform, loglike, num_test_samples, make_safe=False):
1✔
1298
        """Store the transform and log-likelihood functions.
1299

1300
        if make_safe is set, make functions safer by accepting malformed
1301
        return shapes and non-finite likelihood values.
1302
        """
1303

1304
        def safe_loglike(x):
1✔
1305
            """Safe wrapper of likelihood function."""
1306
            x = np.asarray(x)
×
1307
            if len(x.shape) == 1:
×
1308
                assert x.shape[0] == self.x_dim
×
1309
                x = np.expand_dims(x, 0)
×
1310
            logl = loglike(x)
×
1311
            if len(logl.shape) == 0:
×
1312
                logl = np.expand_dims(logl, 0)
×
1313
            logl[np.logical_not(np.isfinite(logl))] = -1e100
×
1314
            return logl
×
1315

1316
        if make_safe:
1!
1317
            self.loglike = safe_loglike
×
1318
        else:
1319
            self.loglike = loglike
1✔
1320

1321
        if transform is None:
1✔
1322
            self.transform = lambda x: x
1✔
1323
        elif make_safe:
1!
1324
            def safe_transform(x):
×
1325
                """Safe wrapper of transform function."""
1326
                x = np.asarray(x)
×
1327
                if len(x.shape) == 1:
×
1328
                    assert x.shape[0] == self.x_dim
×
1329
                    x = np.expand_dims(x, 0)
×
1330
                return transform(x)
×
1331
            self.transform = safe_transform
×
1332
        else:
1333
            self.transform = transform
1✔
1334

1335
        lims = np.ones((2, self.x_dim))
1✔
1336
        lims[0,:] = 1e-6
1✔
1337
        lims[1,:] = 1 - 1e-6
1✔
1338
        self.transform_limits = self.transform(lims).transpose()
1✔
1339

1340
        self.volfactor = vol_prefactor(self.x_dim)
1✔
1341

1342
    def _widen_nodes(self, weighted_parents, weights, nnodes_needed, update_interval_ncall):
1✔
1343
        """Ensure that at parents have `nnodes_needed` live points (parallel arcs).
1344

1345
        If not, fill up by sampling.
1346
        """
1347
        ndone = len(weighted_parents)
1✔
1348
        if ndone == 0:
1!
1349
            if self.log:
×
1350
                self.logger.info('No parents, so widening roots')
×
1351
            self._widen_roots(nnodes_needed)
×
1352
            return {}
×
1353

1354
        # select parents with weight 1/parent_weights
1355
        p = 1. / np.array(weights)
1✔
1356
        if (p == p[0]).all():
1✔
1357
            parents = weighted_parents
1✔
1358
        else:
1359
            # preferentially select nodes with few parents, as those
1360
            # have most weight
1361
            i = np.random.choice(len(weighted_parents), size=nnodes_needed, p=p / p.sum())
1✔
1362
            if self.use_mpi:
1!
1363
                i = self.comm.bcast(i, root=0)
×
1364

1365
            parents = [weighted_parents[ii] for ii in i]
1✔
1366

1367
        del weighted_parents, weights
1✔
1368

1369
        # sort from low to high
1370
        parents.sort(key=operator.attrgetter('value'))
1✔
1371
        Lmin = parents[0].value
1✔
1372
        if np.isinf(Lmin):
1!
1373
            # some of the parents were born by sampling from the entire
1374
            # prior volume. So we can efficiently apply a solution:
1375
            # expand the roots
1376
            if self.log:
×
1377
                self.logger.info('parent value is -inf, so widening roots')
×
1378
            self._widen_roots(nnodes_needed)
×
1379
            return {}
×
1380

1381
        # double until we reach the necessary points
1382
        # this is probably 1, from (2K - K) / K
1383
        nsamples = int(np.ceil((nnodes_needed - ndone) / len(parents)))
1✔
1384

1385
        if self.log:
1!
1386
            self.logger.info('Will add %d live points (x%d) at L=%.1g ...', nnodes_needed - ndone, nsamples, Lmin)
1✔
1387

1388
        # add points where necessary (parents can have multiple entries)
1389
        target_min_num_children = {}
1✔
1390
        for n in parents:
1✔
1391
            orign = target_min_num_children.get(n.id, len(n.children))
1✔
1392
            target_min_num_children[n.id] = orign + nsamples
1✔
1393

1394
        return target_min_num_children
1✔
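    # Sketch of the fill-up count above, with illustrative numbers: doubling
    # from 400 to 800 live points over 400 selected parents needs one extra
    # child per parent, matching the (2K - K) / K comment.
    #
    #     nnodes_needed, ndone, nparents = 800, 400, 400
    #     nsamples = int(np.ceil((nnodes_needed - ndone) / nparents))  # -> 1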
1395

1396
    def _widen_roots_beyond_initial_plateau(self, nroots, num_warn, num_stop):
1✔
1397
        """Widen roots, but populate ahead of initial plateau.
1398

1399
        calls _widen_roots, and if there are several points with the same
1400
        value equal to the lowest loglikelihood, widens some more until
1401
        there are `nroots`-1 that are different from the lowest
1402
        loglikelihood value.
1403

1404
        Parameters
1405
        -----------
1406
        nroots: int
1407
            Number of root live points, after the plateau is traversed.
1408

1409
        num_warn: int
1410
            Warn if the number of root live points reaches this.
1411

1412
        num_stop: int
1413
            Do not increase the number of root live points beyond this limit.
1414

1415
        """
1416
        nroots_needed = nroots
1✔
1417
        user_has_been_warned = False
1✔
1418
        while True:
1419
            self._widen_roots(nroots_needed)
1✔
1420
            Ls = np.array([node.value for node in self.root.children])
1✔
1421
            Lmin = np.min(Ls)
1✔
1422
            if self.log and nroots_needed > num_warn and not user_has_been_warned:
1✔
1423
                self.logger.warning("""Warning: The log-likelihood has a large plateau with L=%g.
1✔
1424

1425
  Probably you are returning a low value when the parameters are problematic/unphysical.
1426
  ultranest can handle this correctly, by discarding live points with the same loglikelihood.
1427
  (arxiv:2005.08602 arxiv:2010.13884). To mitigate running out of live points,
1428
  the initial number of live points is increased. But now this has reached over %d points.
1429

1430
  You can avoid this by making the loglikelihood increase towards where the good region is.
1431
  For example, let's say you have two parameters where the sum must be below 1. Replace this:
1432

1433
    if params[0] + params[1] > 1:
1434
         return -1e300
1435

1436
  with:
1437

1438
    if params[0] + params[1] > 1:
1439
         return -1e300 * (params[0] + params[1])
1440

1441
  The current strategy will continue until %d live points are reached.
1442
  It is safe to ignore this warning.""", Lmin, num_warn, num_stop)
1443
                user_has_been_warned = True
1✔
1444

1445
            if nroots_needed >= num_stop:
1!
1446
                break
×
1447
            P = (Ls == Lmin).sum()
1✔
1448
            if 1 < P < len(Ls) and len(Ls) - P + 1 < nroots:
1!
1449
                # guess the number of points needed: P-1 are useless
1450
                if self.log:
1!
1451
                    self.logger.debug(
1✔
1452
                        'Found plateau of %d/%d initial points at L=%g. '
1453
                        'Avoid this by a continuously increasing loglikelihood towards good regions.',
1454
                        P, nroots_needed, Lmin)
1455
                nroots_needed = min(num_stop, nroots_needed + (P - 1))
1✔
1456
            else:
1457
                break
×
1458

1459
    def _widen_roots(self, nroots):
1✔
1460
        """Ensure root has `nroots` children.
1461

1462
        Sample from prior to fill up (if needed).
1463

1464
        Parameters
1465
        -----------
1466
        nroots: int
1467
            Number of root live points, after the plateau is traversed.
1468
        """
1469
        if self.log and len(self.root.children) > 0:
1✔
1470
            self.logger.info('Widening roots to %d live points (have %d already) ...', nroots, len(self.root.children))
1✔
1471
        nnewroots = nroots - len(self.root.children)
1✔
1472
        if nnewroots <= 0:
1!
1473
            # nothing to do
1474
            return
×
1475

1476
        prev_u = []
1✔
1477
        prev_v = []
1✔
1478
        prev_logl = []
1✔
1479
        prev_rowid = []
1✔
1480

1481
        if self.log and self.use_point_stack:
1✔
1482
            # try to resume:
1483
            # self.logger.info('Resuming...')
1484
            for i in range(nnewroots):
1✔
1485
                rowid, row = self.pointstore.pop(-np.inf)
1✔
1486
                if row is None:
1!
1487
                    break
×
1488
                prev_logl.append(row[1])
1✔
1489
                prev_u.append(row[3:3 + self.x_dim])
1✔
1490
                prev_v.append(row[3 + self.x_dim:3 + self.x_dim + self.num_params])
1✔
1491
                prev_rowid.append(rowid)
1✔
1492

1493
        if self.log:
1!
1494
            prev_u = np.array(prev_u)
1✔
1495
            prev_v = np.array(prev_v)
1✔
1496
            prev_logl = np.array(prev_logl)
1✔
1497

1498
            num_live_points_missing = nnewroots - len(prev_logl)
1✔
1499
        else:
1500
            num_live_points_missing = -1
×
1501

1502
        if self.use_mpi:
1!
1503
            num_live_points_missing = self.comm.bcast(num_live_points_missing, root=0)
×
1504
            prev_u = self.comm.bcast(prev_u, root=0)
×
1505
            prev_v = self.comm.bcast(prev_v, root=0)
×
1506
            prev_logl = self.comm.bcast(prev_logl, root=0)
×
1507

1508
        assert num_live_points_missing >= 0
1✔
1509
        if self.log and num_live_points_missing > 0:
1✔
1510
            self.logger.info('Sampling %d live points from prior ...', num_live_points_missing)
1✔
1511
        if num_live_points_missing > 0:
1✔
1512
            num_live_points_todo = distributed_work_chunk_size(num_live_points_missing, self.mpi_rank, self.mpi_size)
1✔
1513
            self.ncall += num_live_points_missing
1✔
1514

1515
            if num_live_points_todo > 0:
1!
1516
                active_u = np.random.uniform(size=(num_live_points_todo, self.x_dim))
1✔
1517
                active_v = self.transform(active_u)
1✔
1518
                active_logl = self.loglike(active_v)
1✔
1519
            else:
1520
                active_u = np.empty((0, self.x_dim))
×
1521
                active_v = np.empty((0, self.num_params))
×
1522
                active_logl = np.empty((0,))
×
1523

1524
            if self.use_mpi:
1!
1525
                recv_samples = self.comm.gather(active_u, root=0)
×
1526
                recv_samplesv = self.comm.gather(active_v, root=0)
×
1527
                recv_likes = self.comm.gather(active_logl, root=0)
×
1528
                recv_samples = self.comm.bcast(recv_samples, root=0)
×
1529
                recv_samplesv = self.comm.bcast(recv_samplesv, root=0)
×
1530
                recv_likes = self.comm.bcast(recv_likes, root=0)
×
1531

1532
                active_u = np.concatenate(recv_samples, axis=0)
×
1533
                active_v = np.concatenate(recv_samplesv, axis=0)
×
1534
                active_logl = np.concatenate(recv_likes, axis=0)
×
1535

1536
            assert active_logl.shape == (num_live_points_missing,), (active_logl.shape, num_live_points_missing)
1✔
1537

1538
            if self.log_to_pointstore:
1✔
1539
                for i in range(num_live_points_missing):
1✔
1540
                    rowid = self.pointstore.add(_listify(
1✔
1541
                        [-np.inf, active_logl[i], 0.0],
1542
                        active_u[i,:],
1543
                        active_v[i,:]), 1)
1544

1545
            if len(prev_u) > 0:
1!
1546
                active_u = np.concatenate((prev_u, active_u))
×
1547
                active_v = np.concatenate((prev_v, active_v))
×
1548
                active_logl = np.concatenate((prev_logl, active_logl))
×
1549
            assert active_u.shape == (nnewroots, self.x_dim), (active_u.shape, nnewroots, self.x_dim, num_live_points_missing, len(prev_u))
1✔
1550
            assert active_v.shape == (nnewroots, self.num_params), (active_v.shape, nnewroots, self.num_params, num_live_points_missing, len(prev_u))
1✔
1551
            assert active_logl.shape == (nnewroots,), (active_logl.shape, nnewroots)
1✔
1552
        else:
1553
            active_u = prev_u
1✔
1554
            active_v = prev_v
1✔
1555
            active_logl = prev_logl
1✔
1556

1557
        roots = [self.pointpile.make_node(logl, u, p) for u, p, logl in zip(active_u, active_v, active_logl)]
1✔
1558
        if len(active_u) > 4:
1!
1559
            self.build_tregion = not is_affine_transform(active_u, active_v)
1✔
1560
        self.root.children += roots
1✔
1561

1562
    def _adaptive_strategy_advice(self, Lmin, parallel_values, main_iterator, minimal_widths, frac_remain, Lepsilon):
1✔
1563
        """Check if integration is done.
1564

1565
        Returns range where more sampling is needed
1566

1567
        Returns
1568
        --------
1569
        Llo: float
1570
            lower log-likelihood bound, nan if done
1571
        Lhi: float
1572
            upper log-likelihood bound, nan if done
1573

1574
        Parameters
1575
        -----------
1576
        Lmin: float
1577
            current loglikelihood threshold
1578
        parallel_values: array of floats
1579
            loglikelihoods of live points
1580
        main_iterator: BreadthFirstIterator
1581
            current tree exploration iterator
1582
        minimal_widths: list
1583
            current width required
1584
        frac_remain: float
1585
            maximum fraction of integral in remainder for termination
1586
        Lepsilon: float
1587
            loglikelihood accuracy threshold
1588
        """
1589
        Ls = parallel_values.copy()
1✔
1590
        Ls.sort()
1✔
1591
        # Ls = [node.value] + [n.value for rootid2, n in parallel_nodes]
1592
        Lmax = Ls[-1]
1✔
1593
        Lmin = Ls[0]
1✔
1594

1595
        # all points the same, stop
1596
        if Lmax - Lmin < Lepsilon:
1✔
1597
            return np.nan, np.nan
1✔
1598

1599
        # max remainder contribution is Lmax + weight, to be added to main_iterator.logZ
1600
        # the likelihood that would add an equal amount as main_iterator.logZ is:
1601
        logZmax = main_iterator.logZremain
1✔
1602
        Lnext = logZmax - (main_iterator.logVolremaining + log(frac_remain)) - log(len(Ls))
1✔
1603
        L1 = Ls[1] if len(Ls) > 1 else Ls[0]
1✔
1604
        Lmax1 = np.median(Ls)
1✔
1605
        Lnext = max(min(Lnext, Lmax1), L1)
1✔
1606

1607
        # if the remainder dominates, return that range
1608
        if main_iterator.logZremain > main_iterator.logZ:
1✔
1609
            return Lmin, Lnext
1✔
1610

1611
        if main_iterator.remainder_fraction > frac_remain:
1✔
1612
            return Lmin, Lnext
1✔
1613

1614
        return np.nan, np.nan
1✔
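    # Sketch of the Lnext threshold computed above, with illustrative numbers:
    # sampling proceeds up to the loglikelihood at which the remaining prior
    # volume would contribute no more than frac_remain of the evidence,
    # distributed over the live points.
    #
    #     logZremain, logVolremaining, frac_remain, nlive = -5.0, -3.0, 0.01, 400
    #     Lnext = logZremain - (logVolremaining + np.log(frac_remain)) - np.log(nlive)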
1615

1616
    def _find_strategy(self, saved_logl, main_iterator, dlogz, dKL, min_ess):
1✔
1617
        """Ask each strategy which log-likelihood interval needs more exploration.
1618

1619
        Returns
1620
        -------
1621
        (Llo_Z, Lhi_Z): floats
1622
            interval where dlogz strategy requires more samples.
1623
        (Llo_KL, Lhi_KL): floats
1624
            interval where posterior uncertainty strategy requires more samples.
1625
        (Llo_ess, Lhi_ess): floats
1626
            interval where effective sample strategy requires more samples.
1627

1628
        Parameters
1629
        ----------
1630
        saved_logl: array of float
1631
            loglikelihood values in integration
1632
        main_iterator: BreadthFirstIterator
1633
            current tree exploration iterator
1634
        dlogz: float
1635
            required logZ accuracy (smaller is stricter)
1636
        dKL: float
1637
            required Kullback-Leibler information gain between bootstrapped
1638
            nested sampling incarnations (smaller is stricter).
1639
        min_ess: float
1640
            required number of effective samples (higher is stricter).
1641

1642
        """
1643
        saved_logl = np.asarray(saved_logl)
1✔
1644
        logw = np.asarray(main_iterator.logweights) + saved_logl.reshape((-1,1)) - main_iterator.all_logZ
1✔
1645
        ref_logw = logw[:,0].reshape((-1,1))
1✔
1646
        other_logw = logw[:,1:]
1✔
1647

1648
        Llo_ess = np.inf
1✔
1649
        Lhi_ess = -np.inf
1✔
1650
        w = exp(ref_logw.flatten())
1✔
1651
        w /= w.sum()
1✔
1652
        ess = len(w) / (1.0 + ((len(w) * w - 1)**2).sum() / len(w))
1✔
1653
        if ess < min_ess:
1✔
1654
            samples = np.random.choice(len(w), p=w, size=min_ess)
1✔
1655
            Llo_ess = saved_logl[samples].min()
1✔
1656
            Lhi_ess = saved_logl[samples].max()
1✔
1657
        if self.log and Lhi_ess > Llo_ess:
1✔
1658
            self.logger.info("Effective samples strategy wants to improve: %.2f..%.2f (ESS = %.1f, need >%d)",
1✔
1659
                             Llo_ess, Lhi_ess, ess, min_ess)
1660
        elif self.log and min_ess > 0:
1✔
1661
            self.logger.info("Effective samples strategy satisfied (ESS = %.1f, need >%d)",
1✔
1662
                             ess, min_ess)
1663

1664
        # compute KL divergence
1665
        with np.errstate(invalid='ignore'):
1✔
1666
            KL = np.where(np.isfinite(other_logw), exp(other_logw) * (other_logw - ref_logw), 0)
1✔
1667
        KLtot = KL.sum(axis=0)
1✔
1668
        dKLtot = np.abs(KLtot - KLtot.mean())
1✔
1669
        p = np.where(KL > 0, KL, 0)
1✔
1670
        p /= p.sum(axis=0).reshape((1, -1))
1✔
1671

1672
        Llo_KL = np.inf
1✔
1673
        Lhi_KL = -np.inf
1✔
1674
        for i, (pi, dKLi, logwi) in enumerate(zip(p.transpose(), dKLtot, other_logw)):
1✔
1675
            if dKLi > dKL:
1!
1676
                ilo, ihi = _get_cumsum_range(pi, 1. / 400)
×
1677
                # ilo and ihi are most likely missing in this iterator
1678
                # --> select the one before/after in this iterator
1679
                ilos = np.where(np.isfinite(logwi[:ilo]))[0]
×
1680
                ihis = np.where(np.isfinite(logwi[ihi:]))[0]
×
1681
                ilo2 = ilos[-1] if len(ilos) > 0 else 0
×
1682
                ihi2 = (ihi + ihis[0]) if len(ihis) > 0 else -1
×
1683
                # self.logger.info('   - KL[%d] = %.2f: need to improve near %.2f..%.2f --> %.2f..%.2f' % (
1684
                #  i, dKLi, saved_logl[ilo], saved_logl[ihi], saved_logl[ilo2], saved_logl[ihi2]))
1685
                Llo_KL = min(Llo_KL, saved_logl[ilo2])
×
1686
                Lhi_KL = max(Lhi_KL, saved_logl[ihi2])
×
1687

1688
        if self.log and Lhi_KL > Llo_KL:
1!
1689
            self.logger.info("Posterior uncertainty strategy wants to improve: %.2f..%.2f (KL: %.2f+-%.2f nat, need <%.2f nat)",
×
1690
                             Llo_KL, Lhi_KL, KLtot.mean(), dKLtot.max(), dKL)
1691
        elif self.log:
1!
1692
            self.logger.info("Posterior uncertainty strategy is satisfied (KL: %.2f+-%.2f nat, need <%.2f nat)",
1✔
1693
                             KLtot.mean(), dKLtot.max(), dKL)
1694

1695
        Nlive_min = 0
1✔
1696
        p = exp(logw)
1✔
1697
        p /= p.sum(axis=0).reshape((1, -1))
1✔
1698
        deltalogZ = np.abs(main_iterator.all_logZ[1:] - main_iterator.logZ)
1✔
1699

1700
        tail_fraction = w[np.asarray(main_iterator.istail)].sum() / w.sum()
1✔
1701
        logzerr_tail = logaddexp(log(tail_fraction) + main_iterator.logZ, main_iterator.logZ) - main_iterator.logZ
1✔
1702
        maxlogzerr = max(main_iterator.logZerr, deltalogZ.max(), main_iterator.logZerr_bs)
1✔
1703
        if maxlogzerr > dlogz:
1✔
1704
            if logzerr_tail > maxlogzerr:
1!
1705
                if self.log:
×
1706
                    self.logger.info("logz error is dominated by tail. Decrease frac_remain to make progress.")
×
1707
            # very conservative estimation using all iterations
1708
            # this punishes short intervals with many live points
1709
            niter_max = len(saved_logl)
1✔
1710
            Nlive_min = int(np.ceil(niter_max**0.5 / dlogz))
1✔
1711
            if self.log:
1!
1712
                self.logger.debug("  conservative estimate says at least %d live points are needed to reach dlogz goal", Nlive_min)
1✔
1713

1714
            # better estimation:
1715

1716
            # go only up to the iteration where the bulk of logZ lies (randomly sampled)
1717
            itmax = np.random.choice(len(w), p=w)
1✔
1718
            # back out nlive sequence (width changed by (1 - exp(-1/N))*(exp(-1/N)) )
1719
            logweights = np.array(main_iterator.logweights[:itmax])
1✔
1720
            with np.errstate(divide='ignore', invalid='ignore'):
1✔
1721
                widthratio = 1 - np.exp(logweights[1:,0] - logweights[:-1,0])
1✔
1722
                nlive = 1. / np.log((1 - np.sqrt(1 - 4 * widthratio)) / (2 * widthratio))
1✔
1723
                nlive[~np.logical_and(np.isfinite(nlive), nlive > 1)] = 1
1✔
1724

1725
            # build iteration groups
1726
            nlive_sets, niter = np.unique(nlive.astype(int), return_counts=True)
1✔
1727
            if self.log:
1!
1728
                self.logger.debug(
1✔
1729
                    "  number of live points vary between %.0f and %.0f, most (%d/%d iterations) have %d",
1730
                    nlive.min(), nlive.max(), niter.max(), itmax, nlive_sets[niter.argmax()])
1731
            for nlive_floor in nlive_sets:
1✔
1732
                # estimate error if this was the minimum nlive applied
1733
                nlive_adjusted = np.where(nlive_sets < nlive_floor, nlive_floor, nlive_sets)
1✔
1734
                deltalogZ_expected = (niter / nlive_adjusted**2.0).sum()**0.5
1✔
1735
                if deltalogZ_expected < dlogz:
1✔
1736
                    # achievable with Nlive_min
1737
                    Nlive_min = int(nlive_floor)
1✔
1738
                    if self.log:
1!
1739
                        self.logger.debug("  at least %d live points are needed to reach dlogz goal", Nlive_min)
1✔
1740
                    break
1✔
1741

1742
        if self.log and Nlive_min > 0:
1✔
1743
            self.logger.info(
1✔
1744
                "Evidency uncertainty strategy wants %d minimum live points (dlogz from %.2f to %.2f, need <%s)",
1745
                Nlive_min, deltalogZ.mean(), deltalogZ.max(), dlogz)
1746
        elif self.log:
1!
1747
            self.logger.info(
1✔
1748
                "Evidency uncertainty strategy is satisfied (dlogz=%.2f, need <%s)",
1749
                (main_iterator.logZerr_bs**2 + logzerr_tail**2)**0.5, dlogz)
1750
        if self.log:
1!
1751
            self.logger.info(
1✔
1752
                '  logZ error budget: single: %.2f bs:%.2f tail:%.2f total:%.2f required:<%.2f',
1753
                main_iterator.logZerr, main_iterator.logZerr_bs, logzerr_tail,
1754
                (main_iterator.logZerr_bs**2 + logzerr_tail**2)**0.5, dlogz)
1755

1756
        return Nlive_min, (Llo_KL, Lhi_KL), (Llo_ess, Lhi_ess)
1✔
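    # Note on the effective sample size estimate used above: for normalised
    # weights w (w.sum() == 1) it is algebraically identical to the common
    # estimator 1 / (w**2).sum(). Illustrative check:
    #
    #     w = np.array([0.5, 0.3, 0.1, 0.1])
    #     ess = len(w) / (1.0 + ((len(w) * w - 1)**2).sum() / len(w))
    #     assert np.isclose(ess, 1.0 / (w**2).sum())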
1757

1758
    def _refill_samples(self, Lmin, ndraw, nit):
1✔
1759
        """Get new samples from region."""
1760
        nc = 0
1✔
1761
        u = self.region.sample(nsamples=ndraw)
1✔
1762
        assert np.logical_and(u > 0, u < 1).all(), (u)
1✔
1763
        nu = u.shape[0]
1✔
1764
        if nu == 0:
1✔
1765
            v = np.empty((0, self.num_params))
1✔
1766
            logl = np.empty((0,))
1✔
1767
            accepted = np.empty(0, dtype=bool)
1✔
1768
        else:
1769
            if nu > 1 and not self.draw_multiple:
1✔
1770
                # peel off first if multiple evaluation is not supported
1771
                nu = 1
1✔
1772
                u = u[:1,:]
1✔
1773

1774
            v = self.transform(u)
1✔
1775
            logl = np.ones(nu) * -np.inf
1✔
1776

1777
            if self.tregion is not None:
1✔
1778
                # check wrapping ellipsoid in transformed space
1779
                accepted = self.tregion.inside(v)
1✔
1780
                nt = accepted.sum()
1✔
1781
            else:
1782
                # if undefined, all pass; rarer branch
1783
                accepted = np.ones(nu, dtype=bool)
1✔
1784
                nt = nu
1✔
1785

1786
            if nt > 0:
1✔
1787
                logl[accepted] = self.loglike(v[accepted, :])
1✔
1788
                nc += nt
1✔
1789
            accepted = logl > Lmin
1✔
1790

1791
            # print("it: %4d ndraw: %d -> %d -> %d -> %d " % (nit, ndraw, nu, nt, accepted.sum()))
1792

1793
        if not self.sampling_slow_warned and nit * ndraw >= 100000 and nit > 20:
1!
1794
            warning_message1 = ("Sampling from region seems inefficient (%d/%d accepted in iteration %d). " % (accepted.sum(), ndraw, nit))
×
1795
            warning_message2 = "To improve efficiency, modify the transformation so that the current live points%s are ellipsoidal, " + \
×
1796
                "or use a stepsampler, or set frac_remain to a lower number (e.g., 0.5) to terminate earlier."
1797
            if self.log_to_disk:
×
1798
                debug_filename = os.path.join(self.logs['extra'], 'sampling-stuck-it%d' % nit)
×
1799
                np.savez(
×
1800
                    debug_filename + '.npz',
1801
                    u=self.region.u, unormed=self.region.unormed,
1802
                    maxradiussq=self.region.maxradiussq,
1803
                    sample_u=u, sample_v=v, sample_logl=logl)
1804
                np.savetxt(debug_filename + '.csv', self.region.u, delimiter=',')
×
1805
                warning_message = warning_message1 + (warning_message2 % (' (stored for you in %s.csv)' % debug_filename))
×
1806
            else:
1807
                warning_message = warning_message1 + warning_message2 % ''
×
1808
            warnings.warn(warning_message)
×
1809
            logl_region = self.loglike(self.transform(self.region.u))
×
1810
            if (logl_region == Lmin).all():
×
1811
                raise ValueError(
×
1812
                    "Region cannot sample a higher point. "
1813
                    "All remaining live points have the same value.")
1814
            if not (logl_region > Lmin).any():
×
1815
                raise ValueError(
×
1816
                    "Region cannot sample a higher point. "
1817
                    "Perhaps you are resuming from a different problem?"
1818
                    "Delete the output files and start again.")
1819
            self.sampling_slow_warned = True
×
1820

1821
        self.ncall_region += ndraw
1✔
1822
        return u[accepted,:], v[accepted,:], logl[accepted], nc, 0
1✔
1823

1824
    def _create_point(self, Lmin, ndraw, active_u, active_values):
1✔
1825
        """Draw a new point above likelihood threshold `Lmin`.
1826

1827
        Parameters
1828
        -----------
1829
        Lmin: float
1830
            loglikelihood threshold to draw above
1831
        ndraw: int
1832
            number of points to try to sample at once
1833
        active_u: array of floats
1834
            current live points
1835
        active_values: array
1836
            loglikelihoods of current live points
1837

1838
        """
1839
        if self.stepsampler is None:
1✔
1840
            assert self.region.inside(active_u).any(), \
1✔
1841
                ("None of the live points satisfies the current region!",
1842
                 self.region.maxradiussq, self.region.u, self.region.unormed, active_u,
1843
                 getattr(self.region, 'bbox_lo'),
1844
                 getattr(self.region, 'bbox_hi'),
1845
                 getattr(self.region, 'ellipsoid_cov'),
1846
                 getattr(self.region, 'ellipsoid_center'),
1847
                 getattr(self.region, 'ellipsoid_invcov'),
1848
                 getattr(self.region, 'ellipsoid_cov'),
1849
                 )
1850

1851
        nit = 0
1✔
1852
        while True:
1853
            ib = self.ib
1✔
1854
            if ib >= len(self.samples) and self.use_point_stack:
1✔
1855
                # root checks the point store
1856
                next_point = np.zeros((1, 3 + self.x_dim + self.num_params)) * np.nan
1✔
1857

1858
                if self.log_to_pointstore:
1!
1859
                    _, stored_point = self.pointstore.pop(Lmin)
1✔
1860
                    if stored_point is not None:
1!
1861
                        next_point[0,:] = stored_point
1✔
1862
                    else:
1863
                        next_point[0,:] = -np.inf
×
1864
                    self.use_point_stack = not self.pointstore.stack_empty
1✔
1865

1866
                if self.use_mpi:  # and informs everyone
1!
1867
                    self.use_point_stack = self.comm.bcast(self.use_point_stack, root=0)
×
1868
                    next_point = self.comm.bcast(next_point, root=0)
×
1869

1870
                # unpack
1871
                self.likes = next_point[:,1]
1✔
1872
                self.samples = next_point[:,3:3 + self.x_dim]
1✔
1873
                self.samplesv = next_point[:,3 + self.x_dim:3 + self.x_dim + self.num_params]
1✔
1874
                # skip if we already know it is not useful
1875
                ib = 0 if np.isfinite(self.likes[0]) else 1
1✔
1876

1877
            use_stepsampler = self.stepsampler is not None
1✔
1878
            while ib >= len(self.samples):
1✔
1879
                ib = 0
1✔
1880
                if use_stepsampler:
1✔
1881
                    u, v, logl, nc = self.stepsampler.__next__(
1✔
1882
                        self.region,
1883
                        transform=self.transform, loglike=self.loglike,
1884
                        Lmin=Lmin, us=active_u, Ls=active_values,
1885
                        ndraw=ndraw, tregion=self.tregion)
1886
                    quality = self.stepsampler.nsteps
1✔
1887
                else:
1888
                    u, v, logl, nc, quality = self._refill_samples(Lmin, ndraw, nit)
1✔
1889
                nit += 1
1✔
1890

1891
                if logl is None:
1✔
1892
                    u = np.empty((0, self.x_dim))
1✔
1893
                    v = np.empty((0, self.num_params))
1✔
1894
                    logl = np.empty((0,))
1✔
1895
                elif u.ndim == 1:
1✔
1896
                    assert np.logical_and(u > 0, u < 1).all(), (u)
1✔
1897
                    u = u.reshape((1, self.x_dim))
1✔
1898
                    v = v.reshape((1, self.num_params))
1✔
1899
                    logl = logl.reshape((1,))
1✔
1900

1901
                if self.use_mpi:
1!
1902
                    recv_samples = self.comm.gather(u, root=0)
×
1903
                    recv_samplesv = self.comm.gather(v, root=0)
×
1904
                    recv_likes = self.comm.gather(logl, root=0)
×
1905
                    recv_nc = self.comm.gather(nc, root=0)
×
1906
                    recv_samples = self.comm.bcast(recv_samples, root=0)
×
1907
                    recv_samplesv = self.comm.bcast(recv_samplesv, root=0)
×
1908
                    recv_likes = self.comm.bcast(recv_likes, root=0)
×
1909
                    recv_nc = self.comm.bcast(recv_nc, root=0)
×
1910
                    self.samples = np.concatenate(recv_samples, axis=0)
×
1911
                    self.samplesv = np.concatenate(recv_samplesv, axis=0)
×
1912
                    self.likes = np.concatenate(recv_likes, axis=0)
×
1913
                    self.ncall += sum(recv_nc)
×
1914
                else:
1915
                    self.samples = u
1✔
1916
                    self.samplesv = v
1✔
1917
                    self.likes = logl
1✔
1918
                    self.ncall += nc
1✔
1919

1920
                if self.log:
1!
1921
                    for ui, vi, logli in zip(self.samples, self.samplesv, self.likes):
1✔
1922
                        self.pointstore.add(
1✔
1923
                            _listify([Lmin, logli, quality], ui, vi),
1924
                            self.ncall)
1925

1926
            if self.likes[ib] > Lmin:
1✔
1927
                u = self.samples[ib, :]
1✔
1928
                assert np.logical_and(u > 0, u < 1).all(), (u)
1✔
1929
                p = self.samplesv[ib, :]
1✔
1930
                logl = self.likes[ib]
1✔
1931

1932
                self.ib = ib + 1
1✔
1933
                return u, p, logl
1✔
1934
            else:
1935
                self.ib = ib + 1
1✔
1936

1937
    def _update_region(
1✔
1938
        self, active_u, active_node_ids,
1939
        bootstrap_rootids=None, active_rootids=None,
1940
        nbootstraps=30, minvol=0., active_p=None
1941
    ):
1942
        """Build a new MLFriends region from `active_u`, and wrapping ellipsoid.
1943

1944
        Both are safely built using bootstrapping, so that the
1945
        region can be used for sampling and rejecting points.
1946
        If MPI is enabled, this computation is parallelised.
1947

1948
        If active_p is not None, a wrapping ellipsoid is built also
1949
        in the user-transformed parameter space.
1950

1951
        Parameters
1952
        -----------
1953
        active_u: array of floats
1954
            current live points
1955
        active_node_ids: 2d array of ints
1956
            which bootstrap initialisation the points belong to.
1957
        active_rootids: 2d array of ints
1958
            roots active in each bootstrap initialisation
1959
        bootstrap_rootids: array of ints
1960
            bootstrap samples. if None, they are drawn fresh.
1961
        nbootstraps: int
1962
            number of bootstrap rounds
1963
        active_p: array of floats
1964
            current live points, in user-transformed space
1965
        minvol: float
1966
            expected current minimum volume of region.
1967

1968
        Returns
1969
        --------
1970
        updated: bool
1971
            True if update was made, False if previous region remained.
1972

1973
        """
1974
        assert nbootstraps > 0
1✔
1975
        updated = False
1✔
1976
        if self.region is None:
1✔
1977
            # if self.log:
1978
            #    self.logger.debug("building first region ...")
1979
            self.transformLayer = self.transform_layer_class(wrapped_dims=self.wrapped_axes)
1✔
1980
            self.transformLayer.optimize(active_u, active_u, minvol=minvol)
1✔
1981
            self.region = self.region_class(active_u, self.transformLayer)
1✔
1982
            self.region_nodes = active_node_ids.copy()
1✔
1983
            assert self.region.maxradiussq is None
1✔
1984

1985
            _update_region_bootstrap(self.region, nbootstraps, minvol, self.comm if self.use_mpi else None, self.mpi_size)
1✔
1986
            self.region.create_ellipsoid(minvol=minvol)
1✔
1987
            # if self.log:
1988
            #     self.logger.debug("building first region ... r=%e, f=%e" % (r, f))
1989
            updated = True
1✔
1990

1991
            # verify correctness:
1992
            # self.region.create_ellipsoid(minvol=minvol)
1993
            # assert self.region.inside(active_u).all(), self.region.inside(active_u).mean()
1994

1995
        assert self.transformLayer is not None
1✔
1996
        need_accept = False
1✔
1997

1998
        if self.region.maxradiussq is None:
1✔
1999
            # we have been told that radius is currently invalid
2000
            # we need to bootstrap back to a valid state
2001

2002
            # compute radius given current transformLayer
2003
            oldu = self.region.u
1✔
2004
            self.region.u = active_u
1✔
2005
            self.region_nodes = active_node_ids.copy()
1✔
2006
            self.region.set_transformLayer(self.transformLayer)
1✔
2007

2008
            _update_region_bootstrap(self.region, nbootstraps, minvol, self.comm if self.use_mpi else None, self.mpi_size)
1✔
2009

2010
            # print("made first region, r=%e" % (r))
2011

2012
            # now that we have r, can do clustering
2013
            # but such reclustering would forget the cluster ids
2014

2015
            # instead, track the clusters from before by matching manually
2016
            oldt = self.transformLayer.transform(oldu)
1✔
2017
            clusterids = np.zeros(len(active_u), dtype=int)
1✔
2018
            nnearby = np.empty(len(self.region.unormed), dtype=int)
1✔
2019
            for ci in np.unique(self.transformLayer.clusterids):
1✔
2020
                if ci == 0:
1✔
2021
                    continue
1✔
2022

2023
                # find points from that cluster
2024
                oldti = oldt[self.transformLayer.clusterids == ci]
1✔
2025
                # identify which new points are near this cluster
2026
                find_nearby(oldti, self.region.unormed, self.region.maxradiussq, nnearby)
1✔
2027
                mask = nnearby != 0
1✔
2028
                # assign the nearby ones to this cluster
2029
                # if they have not been set yet
2030
                # if they have, set them to -1
2031
                clusterids[mask] = np.where(clusterids[mask] == 0, ci, -1)
1✔
2032

2033
            # clusters we are unsure about (double assignments) go unassigned
2034
            clusterids[clusterids == -1] = 0
1✔
2035

2036
            # tell scaling layer the correct cluster information
2037
            self.transformLayer.clusterids = clusterids
1✔
2038

2039
            # we want the clustering to repeat to remove remaining zeros
2040
            need_accept = (self.transformLayer.clusterids == 0).any()
1✔
2041

2042
            updated = True
1✔
2043
            assert len(self.region.u) == len(self.transformLayer.clusterids)
1✔
2044

2045
            # verify correctness:
2046
            self.region.create_ellipsoid(minvol=minvol)
1✔
2047
            # assert self.region.inside(active_u).all(), self.region.inside(active_u).mean()
2048

2049
        assert len(self.region.u) == len(self.transformLayer.clusterids)
1✔
2050
        # rebuild space
2051
        with warnings.catch_warnings(), np.errstate(all='raise'):
1✔
2052
            try:
1✔
2053
                nextTransformLayer = self.transformLayer.create_new(active_u, self.region.maxradiussq, minvol=minvol)
1✔
2054
                assert not (nextTransformLayer.clusterids == 0).any()
1✔
2055
                _, cluster_sizes = np.unique(nextTransformLayer.clusterids, return_counts=True)
1✔
2056
                smallest_cluster = cluster_sizes.min()
1✔
2057
                if self.log and smallest_cluster == 1:
1✔
2058
                    self.logger.debug(
1✔
2059
                        "clustering found some stray points [need_accept=%s] %s",
2060
                        need_accept,
2061
                        np.unique(nextTransformLayer.clusterids, return_counts=True)
2062
                    )
2063

2064
                nextregion = self.region_class(active_u, nextTransformLayer)
1✔
2065
                assert np.isfinite(nextregion.unormed).all()
1✔
2066

2067
                if not nextTransformLayer.nclusters < 20:
1!
2068
                    if self.log:
×
2069
                        self.logger.info(
×
2070
                            "Found a lot of clusters: %d (%d with >1 members)",
2071
                            nextTransformLayer.nclusters, (cluster_sizes > 1).sum())
2072

2073
                # if self.log:
2074
                #     self.logger.info("computing maxradius...")
2075
                r, f = _update_region_bootstrap(nextregion, nbootstraps, minvol, self.comm if self.use_mpi else None, self.mpi_size)
1✔
2076
                # verify correctness:
2077
                nextregion.create_ellipsoid(minvol=minvol)
1✔
2078

2079
                # check if live points are numerically colliding or linearly dependent
2080
                self.live_points_healthy = len(active_u) > self.x_dim and \
1✔
2081
                    np.all(np.sum(active_u[1:] != active_u[0], axis=0) > self.x_dim) and \
2082
                    np.linalg.matrix_rank(nextregion.ellipsoid_cov) == self.x_dim
2083

2084
                assert (nextregion.u == active_u).all()
1✔
2085
                assert np.allclose(nextregion.unormed, nextregion.transformLayer.transform(active_u))
1✔
2086
                # assert nextregion.inside(active_u).all(),
2087
                #  ("live points should live in new region, but only %.3f%% do." % (100 * nextregion.inside(active_u).mean()), active_u)
2088
                good_region = nextregion.inside(active_u).all()
1✔
2089
                # assert good_region
2090
                if not good_region and self.log:
1✔
2091
                    self.logger.debug("Proposed region is inconsistent (maxr=%g,enlarge=%g) and will be skipped.", r, f)
1✔
2092

2093
                # avoid cases where every point is its own cluster,
2094
                # and even the largest cluster has fewer than x_dim points
2095
                sensible_clustering = nextTransformLayer.nclusters < len(nextregion.u) \
1✔
2096
                    and cluster_sizes.max() >= nextregion.u.shape[1]
2097

2098
                # force shrinkage of volume. avoids reconnecting dying modes
2099
                if good_region and \
1✔
2100
                        (need_accept or nextregion.estimate_volume() <= self.region.estimate_volume()) \
2101
                        and sensible_clustering:
2102
                    self.region = nextregion
1✔
2103
                    self.transformLayer = self.region.transformLayer
1✔
2104
                    self.region_nodes = active_node_ids.copy()
1✔
2105
                    updated = True
1✔
2106

2107
                    assert not (self.transformLayer.clusterids == 0).any(), (self.transformLayer.clusterids, need_accept, updated)
1✔
2108

2109
            except Warning:
×
2110
                if self.log:
×
2111
                    self.logger.debug("not updating region", exc_info=True)
×
2112
            except FloatingPointError:
×
2113
                if self.log:
×
2114
                    self.logger.debug("not updating region", exc_info=True)
×
2115
            except np.linalg.LinAlgError:
×
2116
                if self.log:
×
2117
                    self.logger.debug("not updating region", exc_info=True)
×
2118

2119
        assert len(self.region.u) == len(self.transformLayer.clusterids)
1✔
2120

2121
        if active_p is None or not self.build_tregion:
1✔
2122
            self.tregion = None
1✔
2123
        else:
2124
            try:
1✔
2125
                with np.errstate(invalid='raise'):
1✔
2126
                    tregion = WrappingEllipsoid(active_p)
1✔
2127
                    f = tregion.compute_enlargement(
1✔
2128
                        nbootstraps=max(1, nbootstraps // self.mpi_size))
2129
                    if self.use_mpi:
1!
2130
                        recv_enlarge = self.comm.gather(f, root=0)
×
2131
                        recv_enlarge = self.comm.bcast(recv_enlarge, root=0)
×
2132
                        f = np.max(recv_enlarge)
×
2133
                    tregion.enlarge = f
1✔
2134
                    tregion.create_ellipsoid()
1✔
2135
                    self.tregion = tregion
1✔
2136
            except FloatingPointError:
×
2137
                if self.log:
×
2138
                    self.logger.debug("not updating t-ellipsoid", exc_info=True)
×
2139
                    self.tregion = None
×
2140
            except np.linalg.LinAlgError:
×
2141
                if self.log:
×
2142
                    self.logger.debug("not updating t-ellipsoid", exc_info=True)
×
2143
                    self.tregion = None
×
2144

2145
        return updated
1✔
2146

2147
    def _expand_nodes_before(self, Lmin, nnodes_needed, update_interval_ncall):
1✔
2148
        """Expand nodes before `Lmin` to have `nnodes_needed`.
2149

2150
        Returns
2151
        --------
2152
        Llo: float
2153
            lowest parent sampled (-np.inf if sampling from root)
2154
        Lhi: float
2155
            Lmin
2156
        target_min_num_children: dict
2157
            minimum number of children to maintain per node id between Llo and Lhi
2158

2159
        """
2160
        self.pointstore.reset()
×
2161
        parents, weights = find_nodes_before(self.root, Lmin)
×
2162
        target_min_num_children = self._widen_nodes(parents, weights, nnodes_needed, update_interval_ncall)
×
2163
        if len(parents) == 0:
×
2164
            Llo = -np.inf
×
2165
        else:
2166
            Llo = min(n.value for n in parents)
×
2167
        Lhi = Lmin
×
2168
        return Llo, Lhi, target_min_num_children
×
2169

2170
    def _should_node_be_expanded(
1✔
2171
        self, it, Llo, Lhi, minimal_widths_sequence, target_min_num_children,
2172
        node, parallel_values, max_ncalls, max_iters, live_points_healthy
2173
    ):
2174
        """Check if node needs new children.
2175

2176
        Returns
2177
        -------
2178
        expand_node: bool
2179
            True if should sample a new point
2180
            based on this node (above its likelihood value Lmin).
2181

2182
        Parameters
2183
        ----------
2184
        it: int
2185
            current iteration
2186
        node: node
2187
            The node to consider
2188
        parallel_values: array of floats
2189
            loglikelihoods of live points
2190
        max_ncalls: int
2191
            maximum number of likelihood function calls allowed
2192
        max_iters: int
2193
            maximum number of nested sampling iterations allowed
2194
        Llo: float
2195
            lower loglikelihood bound for the strategy
2196
        Lhi: float
2197
            upper loglikelihood bound for the strategy
2198
        minimal_widths_sequence: list
2199
            list of likelihood intervals with minimum number of live points
2200
        target_min_num_children:
2201
            minimum number of live points currently targeted
2202
        live_points_healthy: bool
2203
            False if the live points have become
2204
            linearly dependent (covariance not full rank)
2205
            or have attained the same exact value in some parameter.
2206

2207
        """
2208
        Lmin = node.value
1✔
2209
        nlive = len(parallel_values)
1✔
2210

2211
        if not (Lmin <= Lhi and Llo <= Lhi):
1✔
2212
            return False
1✔
2213

2214
        if not live_points_healthy:
1!
2215
            if self.log:
×
2216
                self.logger.debug("not expanding, because live points are linearly dependent")
×
2217
            return False
×
2218

2219
        # some reasons to stop:
2220
        if it > 0:
1✔
2221
            if max_ncalls is not None and self.ncall >= max_ncalls:
1!
2222
                # print("not expanding, because above max_ncall")
2223
                return False
×
2224

2225
            if max_iters is not None and it >= max_iters:
1✔
2226
                # print("not expanding, because above max_iters")
2227
                return False
1✔
2228

2229
        # in a plateau, only shrink (Fowlie+2020)
2230
        if (Lmin == parallel_values).sum() > 1:
1✔
2231
            if self.log:
1!
2232
                self.logger.debug("Plateau detected at L=%e, not replacing live point." % Lmin)
1✔
2233
            return False
1✔
2234

2235
        expand_node = False
1✔
2236
        # we should continue to progress towards Lhi
2237
        while Lmin > minimal_widths_sequence[0][0]:
1✔
2238
            minimal_widths_sequence.pop(0)
1✔
2239

2240
        # get currently desired width
2241
        if self.region is None:
1✔
2242
            minimal_width_clusters = 0
1✔
2243
        else:
2244
            # compute number of clusters with more than 1 element
2245
            _, cluster_sizes = np.unique(self.region.transformLayer.clusterids, return_counts=True)
1✔
2246
            nclusters = (cluster_sizes > 1).sum()
1✔
2247
            minimal_width_clusters = self.cluster_num_live_points * nclusters
1✔
2248

2249
        minimal_width = max(minimal_widths_sequence[0][1], minimal_width_clusters)
1✔
2250

2251
        # if already has children, no need to expand
2252
        # if we are wider than the width required
2253
        # we do not need to expand this one
2254
        # expand_node = len(node.children) == 0
2255
        # prefer 1 child, or the number required, if specified
2256
        nmin = target_min_num_children.get(node.id, 1) if target_min_num_children else 1
1✔
2257
        expand_node = len(node.children) < nmin
1✔
2258
        # print("not expanding, because we are quite wide", nlive, minimal_width, minimal_widths_sequence)
2259
        # but we have to expand the first iteration,
2260
        # otherwise the integrator never sets H
2261
        too_wide = nlive > minimal_width and it > 0
1✔
2262

2263
        return expand_node and not too_wide
1✔
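    # Sketch of the plateau rule applied above (Fowlie+2020): when more than one
    # live point sits exactly at Lmin, the dying point is not replaced, so the
    # live-point set only shrinks while the plateau is crossed. Illustrative values:
    #
    #     parallel_values = np.array([1.0, 1.0, 2.0, 3.0])
    #     in_plateau = (parallel_values.min() == parallel_values).sum() > 1  # True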
2264

2265
    def run(
1✔
2266
            self,
2267
            update_interval_volume_fraction=0.8,
2268
            update_interval_ncall=None,
2269
            log_interval=None,
2270
            show_status=True,
2271
            viz_callback='auto',
2272
            dlogz=0.5,
2273
            dKL=0.5,
2274
            frac_remain=0.01,
2275
            Lepsilon=0.001,
2276
            min_ess=400,
2277
            max_iters=None,
2278
            max_ncalls=None,
2279
            max_num_improvement_loops=-1,
2280
            min_num_live_points=400,
2281
            cluster_num_live_points=40,
2282
            insertion_test_window=10,
2283
            insertion_test_zscore_threshold=4,
2284
            region_class=MLFriends,
2285
            widen_before_initial_plateau_num_warn=10000,
2286
            widen_before_initial_plateau_num_max=50000,
2287
    ):
2288
        """Run until target convergence criteria are fulfilled.
2289

2290
        Parameters
2291
        ----------
2292
        update_interval_volume_fraction: float
2293
            Update the region when the volume has shrunk by this amount.
2294

2295
        update_interval_ncall: int
2296
            Update region after update_interval_ncall likelihood calls (not used).
2297

2298
        log_interval: int
2299
            Update stdout status line every log_interval iterations
2300

2301
        show_status: bool
2302
            show integration progress as a status line.
2303
            If no output desired, set to False.
2304

2305
        viz_callback: function
2306
            callback function called when the region is rebuilt. Allows
2307
            showing the current state of the live points.
2308
            See :py:func:`nicelogger` or :py:class:`LivePointsWidget`.
2309
            If no output desired, set to False.
2310

2311
        dlogz: float
2312
            Target evidence uncertainty. This is the std
2313
            between bootstrapped logz integrators.
2314

2315
        dKL: float
2316
            Target posterior uncertainty. This is the
2317
            Kullback-Leibler divergence in nat between bootstrapped integrators.
2318

2319
        frac_remain: float
2320
            Integrate until this fraction of the integral is left in the remainder.
2321
            Set to a low number (1e-2 ... 1e-5) to make sure peaks are discovered.
2322
            Set to a higher number (0.5) if you know the posterior is simple.
2323

2324
        Lepsilon: float
2325
            Terminate when live point likelihoods are all the same,
2326
            within Lepsilon tolerance. Increase this when your likelihood
2327
            function is inaccurate, to avoid unnecessary search.
2328

2329
        min_ess: int
2330
            Target number of effective posterior samples.
2331

2332
        max_iters: int
2333
            maximum number of integration iterations.
2334

2335
        max_ncalls: int
2336
            stop after this many likelihood evaluations.
2337

2338
        max_num_improvement_loops: int
2339
            run() tries to assess iteratively where more samples are needed.
2340
            This number limits the number of improvement loops.
2341

2342
        min_num_live_points: int
2343
            minimum number of live points throughout the run
2344

2345
        cluster_num_live_points: int
2346
            require at least this many live points per detected cluster
2347

2348
        insertion_test_zscore_threshold: float
2349
            z-score used as a threshold for the insertion order test.
2350
            Set to infinity to disable.
2351

2352
        insertion_test_window: int
2353
            Number of iterations after which the insertion order test is reset.
2354

2355
        region_class: :py:class:`MLFriends` or :py:class:`RobustEllipsoidRegion` or :py:class:`SimpleRegion`
2356
            Whether to use MLFriends+ellipsoidal+tellipsoidal region (better for multi-modal problems)
2357
            or just ellipsoidal sampling (faster for high-dimensional, gaussian-like problems)
2358
            or an axis-aligned ellipsoid (fastest, to be combined with slice sampling).
2359

2360
        widen_before_initial_plateau_num_warn: int
2361
            If a likelihood plateau is encountered, increase the number
2362
            of initial live points so that once the plateau is traversed,
2363
            *min_num_live_points* live points remain.
2364
            If the number exceeds *widen_before_initial_plateau_num_warn*,
2365
            a warning is raised.
2366

2367
        widen_before_initial_plateau_num_max: int
2368
            If a likelihood plateau is encountered, increase the number
2369
            of initial live points so that once the plateau is traversed,
2370
            *min_num_live_points* live points remain, but not more than
2371
            *widen_before_initial_plateau_num_warn*.
2372
        """
2373
        for result in self.run_iter(
1✔
2374
            update_interval_volume_fraction=update_interval_volume_fraction,
2375
            update_interval_ncall=update_interval_ncall,
2376
            log_interval=log_interval,
2377
            dlogz=dlogz, dKL=dKL,
2378
            Lepsilon=Lepsilon, frac_remain=frac_remain,
2379
            min_ess=min_ess, max_iters=max_iters,
2380
            max_ncalls=max_ncalls, max_num_improvement_loops=max_num_improvement_loops,
2381
            min_num_live_points=min_num_live_points,
2382
            cluster_num_live_points=cluster_num_live_points,
2383
            show_status=show_status,
2384
            viz_callback=viz_callback,
2385
            insertion_test_window=insertion_test_window,
2386
            insertion_test_zscore_threshold=insertion_test_zscore_threshold,
2387
            region_class=region_class,
2388
            widen_before_initial_plateau_num_warn=widen_before_initial_plateau_num_warn,
2389
            widen_before_initial_plateau_num_max=widen_before_initial_plateau_num_max,
2390
        ):
2391
            if self.log:
1!
2392
                self.logger.debug("did a run_iter pass!")
1✔
2393
            pass
1✔
2394
        if self.log:
1!
2395
            self.logger.info("done iterating.")
1✔
2396

2397
        return self.results
1✔
2398
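    # Illustrative usage sketch (not from the original source; ``paramnames``,
    # ``loglike``, ``transform`` and the parameter values below are assumptions
    # chosen only to echo the docstring above, not recommended defaults):
    #
    #     sampler = ReactiveNestedSampler(paramnames, loglike, transform)
    #     result = sampler.run(
    #         min_num_live_points=1000,  # smaller dlogz becomes reachable
    #         frac_remain=1e-3,          # integrate deeper into the remainder
    #         dlogz=0.1,                 # tighter evidence uncertainty target
    #         min_ess=1000,              # more effective posterior samples
    #     )
    #     print(result['logz'], result['logzerr'])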

2399
    def run_iter(
1✔
2400
            self,
2401
            update_interval_volume_fraction=0.8,
2402
            update_interval_ncall=None,
2403
            log_interval=None,
2404
            dlogz=0.5,
2405
            dKL=0.5,
2406
            frac_remain=0.01,
2407
            Lepsilon=0.001,
2408
            min_ess=400,
2409
            max_iters=None,
2410
            max_ncalls=None,
2411
            max_num_improvement_loops=-1,
2412
            min_num_live_points=400,
2413
            cluster_num_live_points=40,
2414
            show_status=True,
2415
            viz_callback='auto',
2416
            insertion_test_window=10000,
2417
            insertion_test_zscore_threshold=2,
2418
            region_class=MLFriends,
2419
            widen_before_initial_plateau_num_warn=10000,
2420
            widen_before_initial_plateau_num_max=50000,
2421
    ):
2422
        """Iterate towards convergence.
2423

2424
        Use as an iterator like so::
2425

2426
            for result in sampler.run_iter(...):
2427
                print('lnZ = %(logz).2f +- %(logzerr).2f' % result)
2428

2429
        Parameters are as described in the run() method.
2430

2431
        Yields
2432
        ------
2433
        results: dict
2434
        """
2435
        # frac_remain=1  means 1:1 -> dlogz=log(0.5)
2436
        # frac_remain=0.1 means 1:10 -> dlogz=log(0.1)
2437
        # dlogz_min = log(1./(1 + frac_remain))
2438
        # dlogz_min = -log1p(frac_remain)
2439
        if -np.log1p(frac_remain) > dlogz:
1!
2440
            raise ValueError("To achieve the desired logz accuracy, set frac_remain to a value much smaller than %s (currently: %s)" % (
×
2441
                exp(-dlogz) - 1, frac_remain))
2442

2443
        # the error is approximately dlogz = sqrt(iterations) / Nlive
2444
        # so we need a minimum, which depends on the number of iterations
2445
        # fewer than 1000 iterations is quite unlikely
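        # worked example (illustrative): with the default dlogz=0.5 this floor
        # is ceil(sqrt(1000) / 0.5) = 64 live points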
2446
        if min_num_live_points < 1000**0.5 / dlogz:
1✔
2447
            min_num_live_points = int(np.ceil(1000**0.5 / dlogz))
1✔
2448
            if self.log:
1!
2449
                self.logger.info("To achieve the desired logz accuracy, min_num_live_points was increased to %d" % (
1✔
2450
                    min_num_live_points))
2451

2452
        if self.log_to_pointstore:
1✔
2453
            if len(self.pointstore.stack) > 0:
1✔
2454
                self.logger.info("Resuming from %d stored points", len(self.pointstore.stack))
1✔
2455
            self.use_point_stack = not self.pointstore.stack_empty
1✔
2456
        else:
2457
            self.use_point_stack = False
1✔
2458

2459
        assert min_num_live_points >= cluster_num_live_points, \
1✔
2460
            ('min_num_live_points(%d) cannot be less than cluster_num_live_points(%d)' %
2461
                (min_num_live_points, cluster_num_live_points))
2462
        self.min_num_live_points = min_num_live_points
1✔
2463
        self.cluster_num_live_points = cluster_num_live_points
1✔
2464
        self.sampling_slow_warned = False
1✔
2465
        self.build_tregion = True
1✔
2466
        self.region_class = region_class
1✔
2467
        update_interval_volume_log_fraction = log(update_interval_volume_fraction)
1✔
2468

2469
        if viz_callback == 'auto':
1✔
2470
            viz_callback = get_default_viz_callback()
1✔
2471

2472
        self._widen_roots_beyond_initial_plateau(
1✔
2473
            min_num_live_points,
2474
            widen_before_initial_plateau_num_warn, widen_before_initial_plateau_num_max)
2475

2476
        Llo, Lhi = -np.inf, np.inf
1✔
2477
        Lmax = -np.inf
1✔
2478
        strategy_stale = True
1✔
2479
        minimal_widths = []
1✔
2480
        target_min_num_children = {}
1✔
2481
        improvement_it = 0
1✔
2482

2483
        assert max_iters is None or max_iters > 0, ("Invalid value for max_iters: %s. Set to None or positive number" % max_iters)
1✔
2484
        assert max_ncalls is None or max_ncalls > 0, ("Invalid value for max_ncalls: %s. Set to None or positive number" % max_ncalls)
1✔
2485

2486
        if self.log:
1!
2487
            self.logger.debug(
1✔
2488
                'run_iter dlogz=%.1f, dKL=%.1f, frac_remain=%.2f, Lepsilon=%.4f, min_ess=%d' % (
2489
                    dlogz, dKL, frac_remain, Lepsilon, min_ess)
2490
            )
2491
            self.logger.debug(
1✔
2492
                'max_iters=%d, max_ncalls=%d, max_num_improvement_loops=%d, min_num_live_points=%d, cluster_num_live_points=%d' % (
2493
                    max_iters if max_iters else -1, max_ncalls if max_ncalls else -1,
2494
                    max_num_improvement_loops, min_num_live_points, cluster_num_live_points)
2495
            )
2496

2497
        self.results = None
1✔
2498

2499
        while True:
2500
            roots = self.root.children
1✔
2501

2502
            nroots = len(roots)
1✔
2503

2504
            if update_interval_ncall is None:
1✔
2505
                update_interval_ncall = nroots
1✔
2506

2507
            if log_interval is None:
1✔
2508
                log_interval = max(1, round(0.1 * nroots))
1✔
2509
            else:
2510
                log_interval = round(log_interval)
1✔
2511
                if log_interval < 1:
1!
2512
                    raise ValueError("log_interval must be >= 1")
×
2513

2514
            explorer = BreadthFirstIterator(roots)
1✔
2515
            # integrator accumulating the evidence (main run plus bootstrap realisations)
2516
            main_iterator = MultiCounter(
1✔
2517
                nroots=len(roots),
2518
                nbootstraps=max(1, self.num_bootstraps // self.mpi_size),
2519
                random=False, check_insertion_order=False)
2520
            main_iterator.Lmax = max(Lmax, max(n.value for n in roots))
1✔
2521
            insertion_test = UniformOrderAccumulator()
1✔
2522
            insertion_test_runs = []
1✔
2523
            insertion_test_quality = np.inf
1✔
2524
            insertion_test_direction = 0
1✔
2525

2526
            self.transformLayer = None
1✔
2527
            self.region = None
1✔
2528
            self.tregion = None
1✔
2529
            self.live_points_healthy = True
1✔
2530
            it_at_first_region = 0
1✔
2531
            self.ib = 0
1✔
2532
            self.samples = []
1✔
2533
            if self.draw_multiple:
1✔
2534
                ndraw = self.ndraw_min
1✔
2535
            else:
2536
                ndraw = 40
1✔
2537
            self.pointstore.reset()
1✔
2538
            if self.log_to_pointstore:
1✔
2539
                self.use_point_stack = not self.pointstore.stack_empty
1✔
2540
            else:
2541
                self.use_point_stack = False
1✔
2542
            if self.use_mpi:
1!
2543
                self.use_point_stack = self.comm.bcast(self.use_point_stack, root=0)
×
2544

2545
            if self.log and (np.isfinite(Llo) or np.isfinite(Lhi)):
1✔
2546
                self.logger.info("Exploring (in particular: L=%.2f..%.2f) ...", Llo, Lhi)
1✔
2547
            region_sequence = []
1✔
2548
            minimal_widths_sequence = _sequentialize_width_sequence(minimal_widths, self.min_num_live_points)
1✔
2549
            if self.log:
1!
2550
                self.logger.debug('minimal_widths_sequence: %s', minimal_widths_sequence)
1✔
2551

2552
            saved_nodeids = []
1✔
2553
            saved_logl = []
1✔
2554
            it = 0
1✔
2555
            ncall_at_run_start = self.ncall
1✔
2556
            ncall_region_at_run_start = self.ncall_region
1✔
2557
            next_update_interval_volume = 1
1✔
2558
            last_status = time.time()
1✔
2559

2560
            # we go through each live point (regardless of root) by likelihood value
2561
            while True:
2562
                next_node = explorer.next_node()
1✔
2563
                if next_node is None:
1✔
2564
                    break
1✔
2565
                rootid, node, (_, active_rootids, active_values, active_node_ids) = next_node
1✔
2566
                assert not isinstance(rootid, float)
1✔
2567
                # this is the likelihood level we have to improve upon
2568
                self.Lmin = Lmin = node.value
1✔
2569

2570
                # if within suggested range, expand
2571
                if strategy_stale or not (Lmin <= Lhi) or not np.isfinite(Lhi) or (active_values == Lmin).all():
1✔
2572
                    # check with advisor if we want to expand this node
2573
                    Llo, Lhi = self._adaptive_strategy_advice(
1✔
2574
                        Lmin, active_values, main_iterator,
2575
                        minimal_widths, frac_remain, Lepsilon=Lepsilon)
2576
                    # when we are going to the peak, numerical accuracy
2577
                    # can become an issue. We should try not to get stuck there
2578
                    strategy_stale = Lhi - Llo < max(Lepsilon, 0.01)
1✔
2579

2580
                expand_node = self._should_node_be_expanded(
1✔
2581
                    it, Llo, Lhi, minimal_widths_sequence,
2582
                    target_min_num_children, node, active_values,
2583
                    max_ncalls, max_iters, self.live_points_healthy)
2584

2585
                region_fresh = False
1✔
2586
                if expand_node:
1✔
2587
                    # sample a new point above Lmin
2588
                    active_u = self.pointpile.getu(active_node_ids)
1✔
2589
                    active_p = self.pointpile.getp(active_node_ids)
1✔
2590
                    nlive = len(active_u)
1✔
2591
                    # first we check that the region is up-to-date
2592
                    if main_iterator.logVolremaining < next_update_interval_volume:
1✔
2593
                        if self.region is None:
1✔
2594
                            it_at_first_region = it
1✔
2595
                        region_fresh = self._update_region(
1✔
2596
                            active_u=active_u, active_p=active_p, active_node_ids=active_node_ids,
2597
                            active_rootids=active_rootids,
2598
                            bootstrap_rootids=main_iterator.rootids[1:,],
2599
                            nbootstraps=self.num_bootstraps,
2600
                            minvol=exp(main_iterator.logVolremaining))
2601

2602
                        if region_fresh and self.stepsampler is not None:
1✔
2603
                            self.stepsampler.region_changed(active_values, self.region)
1✔
2604

2605
                        _, cluster_sizes = np.unique(self.region.transformLayer.clusterids, return_counts=True)
1✔
2606
                        nclusters = (cluster_sizes > 1).sum()
1✔
2607
                        region_sequence.append((Lmin, nlive, nclusters, np.max(active_values)))
1✔
2608

2609
                        # next_update_interval_ncall = self.ncall + (update_interval_ncall or nlive)
2610
                        next_update_interval_volume = main_iterator.logVolremaining + update_interval_volume_log_fraction
1✔
2611

2612
                        # provide nice output to follow what is going on
2613
                        # but skip if we are resuming
2614
                        #  and (self.ncall != ncall_at_run_start and it_at_first_region == it)
2615
                        if self.log and viz_callback:
1✔
2616
                            viz_callback(
1✔
2617
                                points=dict(u=active_u, p=active_p, logl=active_values),
2618
                                info=dict(
2619
                                    it=it, ncall=self.ncall,
2620
                                    logz=main_iterator.logZ,
2621
                                    logz_remain=main_iterator.logZremain,
2622
                                    logvol=main_iterator.logVolremaining,
2623
                                    paramnames=self.paramnames + self.derivedparamnames,
2624
                                    paramlims=self.transform_limits,
2625
                                    order_test_correlation=insertion_test_quality,
2626
                                    order_test_direction=insertion_test_direction,
2627
                                    stepsampler_info=self.stepsampler.get_info_dict() if hasattr(self.stepsampler, 'get_info_dict') else {}
2628
                                ),
2629
                                region=self.region, transformLayer=self.transformLayer,
2630
                                region_fresh=region_fresh,
2631
                            )
2632
                        if self.log:
1!
2633
                            self.pointstore.flush()
1✔
2634

2635
                    if nlive < cluster_num_live_points * nclusters and improvement_it < max_num_improvement_loops:
1!
2636
                        # make wider here
2637
                        if self.log:
×
2638
                            self.logger.info(
×
2639
                                "Found %d clusters, but only have %d live points, want %d.",
2640
                                self.region.transformLayer.nclusters, nlive,
2641
                                cluster_num_live_points * nclusters)
2642
                        break
×
2643

2644
                    # sample point
2645
                    u, p, L = self._create_point(Lmin=Lmin, ndraw=ndraw, active_u=active_u, active_values=active_values)
1✔
2646
                    child = self.pointpile.make_node(L, u, p)
1✔
2647
                    main_iterator.Lmax = max(main_iterator.Lmax, L)
1✔
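                    # insertion rank test: if sampling above Lmin is unbiased, the rank
                    # of the new point among the current live points is uniform; a large
                    # |z-score| flags systematically too-high or too-low draws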
2648
                    if np.isfinite(insertion_test_zscore_threshold) and nlive > 1:
1!
2649
                        insertion_test.add((active_values < L).sum(), nlive)
1✔
2650
                        if abs(insertion_test.zscore) > insertion_test_zscore_threshold:
1✔
2651
                            insertion_test_runs.append(insertion_test.N)
1✔
2652
                            insertion_test_quality = insertion_test.N
1✔
2653
                            insertion_test_direction = np.sign(insertion_test.zscore)
1✔
2654
                            insertion_test.reset()
1✔
2655
                        elif insertion_test.N > insertion_test_window:
1✔
2656
                            insertion_test_quality = np.inf
1✔
2657
                            insertion_test_direction = 0
1✔
2658
                            insertion_test.reset()
1✔
2659

2660
                    # identify which point is being replaced (from when we built the region)
2661
                    worst = np.where(self.region_nodes == node.id)[0]
1✔
2662
                    self.region_nodes[worst] = child.id
1✔
2663
                    # if we keep the region informed about the new live points
2664
                    # then the region follows the live points even if maxradius is not updated
2665
                    self.region.u[worst] = u
1✔
2666
                    self.region.unormed[worst] = self.region.transformLayer.transform(u)
1✔
2667
                    # also re-center the ellipsoid on the updated live points
2668
                    self.region.ellipsoid_center = np.mean(self.region.u, axis=0)
1✔
2669
                    if self.tregion:
1✔
2670
                        self.tregion.update_center(np.mean(active_p, axis=0))
1✔
2671

2672
                    # if we track the cluster assignment, then in the next round
2673
                    # the ids with the same members are likely to have the same id
2674
                    # this is imperfect
2675
                    # transformLayer.clusterids[worst] = transformLayer.clusterids[father[ib]]
2676
                    # so we just mark the replaced ones as "unassigned"
2677
                    self.transformLayer.clusterids[worst] = 0
1✔
2678

2679
                    node.children.append(child)
1✔
2680

2681
                    if self.log and (region_fresh or it % log_interval == 0 or time.time() > last_status + 0.1):
1✔
2682
                        last_status = time.time()
1✔
2683
                        # the number of proposals asked from region
2684
                        ncall_region_here = (self.ncall_region - ncall_region_at_run_start)
1✔
2685
                        # the number of proposals returned by the region
2686
                        ncall_here = self.ncall - ncall_at_run_start
1✔
2687
                        # the number of likelihood evaluations above threshold
2688
                        it_here = it - it_at_first_region
1✔
2689

2690
                        if show_status:
1!
2691
                            if Lmin < -1e8:
1✔
2692
                                txt = 'Z=%.1g(%.2f%%) | Like=%.2g..%.2g [%.4g..%.4g]%s| it/evals=%d/%d eff=%.4f%% N=%d \r'
1✔
2693
                            elif Llo < -1e8:
1✔
2694
                                txt = 'Z=%.1f(%.2f%%) | Like=%.2f..%.2f [%.4g..%.4g]%s| it/evals=%d/%d eff=%.4f%% N=%d \r'
1✔
2695
                            else:
2696
                                txt = 'Z=%.1f(%.2f%%) | Like=%.2f..%.2f [%.4f..%.4f]%s| it/evals=%d/%d eff=%.4f%% N=%d \r'
1✔
2697
                            sys.stdout.write(txt % (
1✔
2698
                                main_iterator.logZ, 100 * (1 - main_iterator.remainder_fraction),
2699
                                Lmin, main_iterator.Lmax, Llo, Lhi, '*' if strategy_stale else ' ', it, self.ncall,
2700
                                np.inf if ncall_here == 0 else it_here * 100 / ncall_here,
2701
                                nlive))
2702
                            sys.stdout.flush()
1✔
2703
                        self.logger.debug('iteration=%d, ncalls=%d, regioncalls=%d, ndraw=%d, logz=%.2f, remainder_fraction=%.4f%%, Lmin=%.2f, Lmax=%.2f' % (
1✔
2704
                            it, self.ncall, self.ncall_region, ndraw, main_iterator.logZ,
2705
                            100 * main_iterator.remainder_fraction, Lmin, main_iterator.Lmax))
2706

2707
                        # if efficiency becomes low, bulk-process larger arrays
2708
                        if self.draw_multiple:
1✔
2709
                            # inefficiency is the number of (region) proposals per accepted iteration
2710
                            # but improves by parallelism (because we need only the per-process inefficiency)
2711
                            # sampling_inefficiency = (self.ncall - ncall_at_run_start + 1) / (it + 1) / self.mpi_size
2712
                            sampling_inefficiency = (ncall_region_here + 1) / (it_here + 1) / self.mpi_size
1✔
2713

2714
                            # smooth update:
2715
                            ndraw_next = 0.04 * sampling_inefficiency + ndraw * 0.96
1✔
2716
                            ndraw = max(self.ndraw_min, min(self.ndraw_max, round(ndraw_next), ndraw * 100))
1✔
2717

2718
                            if sampling_inefficiency > 100000 and it >= it_at_first_region + 10:
1!
2719
                                # if the efficiency is poor, there are enough samples in each iteration
2720
                                # to estimate the inefficiency
2721
                                ncall_at_run_start = self.ncall
×
2722
                                it_at_first_region = it
×
2723
                                ncall_region_at_run_start = self.ncall_region
×
2724

2725
                else:
2726
                    # we do not want to count iterations without work
2727
                    # otherwise efficiency becomes > 1
2728
                    it_at_first_region += 1
1✔
2729

2730
                saved_nodeids.append(node.id)
1✔
2731
                saved_logl.append(Lmin)
1✔
2732

2733
                # inform iterators (if it is their business) about the arc
2734
                main_iterator.passing_node(rootid, node, active_rootids, active_values)
1✔
2735
                if len(node.children) == 0 and self.region is not None:
1✔
2736
                    # the region radius needs to increase if nlive decreases
2737
                    # radius is not reliable, so invalidate it
2738
                    # (heuristics do not work in practice)
2739
                    self.region.maxradiussq = None
1✔
2740
                    # ask for the region to be rebuilt
2741
                    next_update_interval_volume = 1
1✔
2742

2743
                it += 1
1✔
2744
                explorer.expand_children_of(rootid, node)
1✔
2745

2746
            if self.log:
1!
2747
                self.logger.info("Explored until L=%.1g  ", node.value)
1✔
2748
            # print_tree(roots[::10])
2749

2750
            self.pointstore.flush()
1✔
2751
            self._update_results(main_iterator, saved_logl, saved_nodeids)
1✔
2752
            yield self.results
1✔
2753

2754
            if max_ncalls is not None and self.ncall >= max_ncalls:
1!
2755
                if self.log:
×
2756
                    self.logger.info(
×
2757
                        'Reached maximum number of likelihood calls (%d > %d)...',
2758
                        self.ncall, max_ncalls)
2759
                break
×
2760

2761
            improvement_it += 1
1✔
2762
            if max_num_improvement_loops >= 0 and improvement_it > max_num_improvement_loops:
1✔
2763
                if self.log:
1!
2764
                    self.logger.info('Reached maximum number of improvement loops.')
1✔
2765
                break
1✔
2766

2767
            if ncall_at_run_start == self.ncall and improvement_it > 1:
1!
2768
                if self.log:
×
2769
                    self.logger.info(
×
2770
                        'No changes made. '
2771
                        'Probably the strategy was to explore in the remainder, '
2772
                        'but it is irrelevant already; try decreasing frac_remain.')
2773
                break
×
2774

2775
            Lmax = main_iterator.Lmax
1✔
2776
            if len(region_sequence) > 0:
1✔
2777
                Lmin, nlive, nclusters, Lhi = region_sequence[-1]
1✔
2778
                nnodes_needed = cluster_num_live_points * nclusters
1✔
2779
                if nlive < nnodes_needed:
1!
2780
                    Llo, _, target_min_num_children_new = self._expand_nodes_before(Lmin, nnodes_needed, update_interval_ncall or nlive)
×
2781
                    target_min_num_children.update(target_min_num_children_new)
×
2782
                    # if self.log:
2783
                    #     print_tree(self.root.children[::10])
2784
                    minimal_widths.append((Llo, Lhi, nnodes_needed))
×
2785
                    Llo, Lhi = -np.inf, np.inf
×
2786
                    continue
×
2787

2788
            if self.log:
1!
2789
                # self.logger.info('  logZ = %.4f +- %.4f (main)' % (main_iterator.logZ, main_iterator.logZerr))
2790
                self.logger.info('  logZ = %.4g +- %.4g', main_iterator.logZ_bs, main_iterator.logZerr_bs)
1✔
2791

2792
            saved_logl = np.asarray(saved_logl)
1✔
2793
            # reactive nested sampling: see where we have to improve
2794
            dlogz_min_num_live_points, (Llo_KL, Lhi_KL), (Llo_ess, Lhi_ess) = self._find_strategy(
1✔
2795
                saved_logl, main_iterator, dlogz=dlogz, dKL=dKL, min_ess=min_ess)
2796
            Llo = min(Llo_ess, Llo_KL)
1✔
2797
            Lhi = max(Lhi_ess, Lhi_KL)
1✔
2798
            # to avoid numerical issues when all likelihood values are the same
2799
            Lhi = min(Lhi, saved_logl.max() - 0.001)
1✔
2800

2801
            if self.use_mpi:
1!
2802
                recv_Llo = self.comm.gather(Llo, root=0)
×
2803
                recv_Llo = self.comm.bcast(recv_Llo, root=0)
×
2804
                recv_Lhi = self.comm.gather(Lhi, root=0)
×
2805
                recv_Lhi = self.comm.bcast(recv_Lhi, root=0)
×
2806
                recv_dlogz_min_num_live_points = self.comm.gather(dlogz_min_num_live_points, root=0)
×
2807
                recv_dlogz_min_num_live_points = self.comm.bcast(recv_dlogz_min_num_live_points, root=0)
×
2808

2809
                Llo = min(recv_Llo)
×
2810
                Lhi = max(recv_Lhi)
×
2811
                dlogz_min_num_live_points = max(recv_dlogz_min_num_live_points)
×
2812

2813
            if dlogz_min_num_live_points > self.min_num_live_points:
1✔
2814
                # more live points needed throughout to reach target
2815
                self.min_num_live_points = dlogz_min_num_live_points
1✔
2816
                self._widen_roots_beyond_initial_plateau(
1✔
2817
                    self.min_num_live_points,
2818
                    widen_before_initial_plateau_num_warn,
2819
                    widen_before_initial_plateau_num_max)
2820

2821
            elif Llo <= Lhi:
1!
2822
                # if self.log:
2823
                #     print_tree(roots, title="Tree before forking:")
2824
                parents, parent_weights = find_nodes_before(self.root, Llo)
1✔
2825
                # double the width / live points:
2826
                _, width = count_tree_between(self.root.children, Llo, Lhi)
1✔
2827
                nnodes_needed = width * 2
1✔
2828
                if self.log:
1!
2829
                    self.logger.info(
1✔
2830
                        'Widening from %d to %d live points before L=%.1g...',
2831
                        len(parents), nnodes_needed, Llo)
2832

2833
                if len(parents) == 0:
1!
2834
                    Llo = -np.inf
×
2835
                else:
2836
                    Llo = min(n.value for n in parents)
1✔
2837
                self.pointstore.reset()
1✔
2838
                target_min_num_children.update(self._widen_nodes(parents, parent_weights, nnodes_needed, update_interval_ncall))
1✔
2839
                minimal_widths.append((Llo, Lhi, nnodes_needed))
1✔
2840
                # if self.log:
2841
                #     print_tree(roots, title="Tree after forking:")
2842
                # print('tree size:', count_tree(roots))
2843
            else:
2844
                break
×
2845

2846
    def _update_results(self, main_iterator, saved_logl, saved_nodeids):
1✔
2847
        if self.log:
1!
2848
            self.logger.info('Likelihood function evaluations: %d', self.ncall)
1✔
2849

2850
        results = combine_results(
1✔
2851
            saved_logl, saved_nodeids, self.pointpile,
2852
            main_iterator, mpi_comm=self.comm if self.use_mpi else None)
2853

2854
        results['ncall'] = int(self.ncall)
1✔
2855
        results['paramnames'] = self.paramnames + self.derivedparamnames
1✔
2856
        results['logzerr_single'] = (main_iterator.all_H[0] / self.min_num_live_points)**0.5
1✔
2857

2858
        sequence, results2 = logz_sequence(self.root, self.pointpile, random=True, check_insertion_order=True)
1✔
2859
        results['insertion_order_MWW_test'] = results2['insertion_order_MWW_test']
1✔
2860

2861
        results_simple = dict(results)
1✔
2862
        weighted_samples = results_simple.pop('weighted_samples')
1✔
2863
        samples = results_simple.pop('samples')
1✔
2864
        saved_wt0 = weighted_samples['weights']
1✔
2865
        saved_u = weighted_samples['upoints']
1✔
2866
        saved_v = weighted_samples['points']
1✔
2867

2868
        if self.log_to_disk:
1✔
2869
            if self.log:
1!
2870
                self.logger.info("Writing samples and results to disk ...")
1✔
2871
            np.savetxt(os.path.join(self.logs['chains'], 'equal_weighted_post.txt'),
1✔
2872
                       samples,
2873
                       header=' '.join(self.paramnames + self.derivedparamnames),
2874
                       comments='')
2875
            np.savetxt(os.path.join(self.logs['chains'], 'weighted_post.txt'),
1✔
2876
                       np.hstack((saved_wt0.reshape((-1, 1)), np.reshape(saved_logl, (-1, 1)), saved_v)),
2877
                       header=' '.join(['weight', 'logl'] + self.paramnames + self.derivedparamnames),
2878
                       comments='')
2879
            np.savetxt(os.path.join(self.logs['chains'], 'weighted_post_untransformed.txt'),
1✔
2880
                       np.hstack((saved_wt0.reshape((-1, 1)), np.reshape(saved_logl, (-1, 1)), saved_u)),
2881
                       header=' '.join(['weight', 'logl'] + self.paramnames + self.derivedparamnames),
2882
                       comments='')
2883

2884
            with open(os.path.join(self.logs['info'], 'results.json'), 'w') as f:
1✔
2885
                json.dump(results_simple, f, indent=4)
1✔
2886

2887
            np.savetxt(
1✔
2888
                os.path.join(self.logs['info'], 'post_summary.csv'),
2889
                [[results['posterior'][k][i] for i in range(self.num_params) for k in ('mean', 'stdev', 'median', 'errlo', 'errup')]],
2890
                header=','.join(['"{0}_mean","{0}_stdev","{0}_median","{0}_errlo","{0}_errup"'.format(k)
2891
                                 for k in self.paramnames + self.derivedparamnames]),
2892
                delimiter=',', comments='',
2893
            )
2894

2895
        if self.log_to_disk:
1✔
2896
            keys = 'logz', 'logzerr', 'logvol', 'nlive', 'logl', 'logwt', 'insert_order'
1✔
2897
            np.savetxt(os.path.join(self.logs['chains'], 'run.txt'),
1✔
2898
                       np.hstack(tuple([np.reshape(sequence[k], (-1, 1)) for k in keys])),
2899
                       header=' '.join(keys),
2900
                       comments='')
2901
            if self.log:
1!
2902
                self.logger.info("Writing samples and results to disk ... done")
1✔
2903

2904
        self.results = results
1✔
2905
        self.run_sequence = sequence
1✔
2906

2907
    def store_tree(self):
1✔
2908
        """Store tree to disk (results/tree.hdf5)."""
2909
        if self.log_to_disk:
×
2910
            dump_tree(os.path.join(self.logs['results'], 'tree.hdf5'),
×
2911
                      self.root.children, self.pointpile)
2912

2913
    def print_results(self, use_unicode=True):
1✔
2914
        """Give summary of marginal likelihood and parameter posteriors.
2915

2916
        Parameters
2917
        ----------
2918
        use_unicode: bool
2919
            Whether to print a unicode plot of the posterior distributions
2920

2921
        """
2922
        if self.log:
1!
2923
            print()
1✔
2924
            print('logZ = %(logz).3f +- %(logzerr).3f' % self.results)
1✔
2925
            print('  single instance: logZ = %(logz_single).3f +- %(logzerr_single).3f' % self.results)
1✔
2926
            print('  bootstrapped   : logZ = %(logz_bs).3f +- %(logzerr_bs).3f' % self.results)
1✔
2927
            print('  tail           : logZ = +- %(logzerr_tail).3f' % self.results)
1✔
2928
            print('insert order U test : converged: %(converged)s correlation: %(independent_iterations)s iterations' % (
1✔
2929
                self.results['insertion_order_MWW_test']))
2930
            if self.stepsampler and hasattr(self.stepsampler, 'print_diagnostic'):
1✔
2931
                self.stepsampler.print_diagnostic()
1✔
2932

2933
            print()
1✔
2934
            for i, p in enumerate(self.paramnames + self.derivedparamnames):
1✔
2935
                v = self.results['samples'][:,i]
1✔
2936
                sigma = v.std()
1✔
2937
                med = v.mean()
1✔
2938
                if sigma == 0:
1!
2939
                    j = 3
×
2940
                else:
2941
                    j = max(0, int(-np.floor(np.log10(sigma))) + 1)
1✔
2942
                fmt = '%%.%df' % j
1✔
2943
                try:
1✔
2944
                    if not use_unicode:
1!
2945
                        raise UnicodeEncodeError("")
×
2946
                    # make fancy terminal visualisation on a best-effort basis
2947
                    ' ▁▂▃▄▅▆▇██'.encode(sys.stdout.encoding)
1✔
2948
                    H, edges = np.histogram(v, bins=40)
1✔
2949
                    # add a bit of padding, but not outside parameter limits
2950
                    lo, hi = edges[0], edges[-1]
1✔
2951
                    step = edges[1] - lo
1✔
2952
                    lo = max(self.transform_limits[i,0], lo - 2 * step)
1✔
2953
                    hi = min(self.transform_limits[i,1], hi + 2 * step)
1✔
2954
                    H, edges = np.histogram(v, bins=np.linspace(lo, hi, 40))
1✔
2955
                    lo, hi = edges[0], edges[-1]
1✔
2956

2957
                    dist = ''.join([' ▁▂▃▄▅▆▇██'[i] for i in np.ceil(H * 7 / H.max()).astype(int)])
1✔
2958
                    print('    %-20s: %-6s│%s│%-6s    %s +- %s' % (p, fmt % lo, dist, fmt % hi, fmt % med, fmt % sigma))
1✔
2959
                except Exception:
1✔
2960
                    fmts = '    %-20s' + fmt + " +- " + fmt
1✔
2961
                    print(fmts % (p, med, sigma))
1✔
2962
            print()
1✔
2963

2964
    def plot(self):
1✔
2965
        """Make corner, run and trace plots.
2966

2967
        calls:
2968

2969
        * plot_corner()
2970
        * plot_run()
2971
        * plot_trace()
2972
        """
2973
        self.plot_corner()
1✔
2974
        self.plot_run()
1✔
2975
        self.plot_trace()
1✔
2976

2977
    def plot_corner(self):
1✔
2978
        """Make corner plot.
2979

2980
        Writes corner plot to plots/ directory if log directory was
2981
        specified, otherwise show interactively.
2982

2983
        This does essentially::
2984

2985
            from ultranest.plot import cornerplot
2986
            cornerplot(results)
2987

2988
        """
2989
        from .plot import cornerplot
1✔
2990
        import matplotlib.pyplot as plt
1✔
2991
        if self.log:
1!
2992
            self.logger.debug('Making corner plot ...')
1✔
2993
        cornerplot(self.results, logger=self.logger if self.log else None)
1✔
2994
        if self.log_to_disk:
1✔
2995
            plt.savefig(os.path.join(self.logs['plots'], 'corner.pdf'), bbox_inches='tight')
1✔
2996
            plt.close()
1✔
2997
            self.logger.debug('Making corner plot ... done')
1✔
2998

2999
    def plot_trace(self):
1✔
3000
        """Make trace plot.
3001

3002
        Write parameter trace diagnostic plots to plots/ directory
3003
        if log directory specified, otherwise show interactively.
3004

3005
        This does essentially::
3006

3007
            from ultranest.plot import traceplot
3008
            traceplot(results=results, labels=paramnames + derivedparamnames)
3009

3010
        """
3011
        from .plot import traceplot
1✔
3012
        import matplotlib.pyplot as plt
1✔
3013
        if self.log:
1!
3014
            self.logger.debug('Making trace plot ... ')
1✔
3015
        paramnames = self.paramnames + self.derivedparamnames
1✔
3016
        # get dynesty-compatible sequences
3017
        traceplot(results=self.run_sequence, labels=paramnames)
1✔
3018
        if self.log_to_disk:
1✔
3019
            plt.savefig(os.path.join(self.logs['plots'], 'trace.pdf'), bbox_inches='tight')
1✔
3020
            plt.close()
1✔
3021
            self.logger.debug('Making trace plot ... done')
1✔
3022

3023
    def plot_run(self):
1✔
3024
        """Make run plot.
3025

3026
        Write run diagnostic plots to plots/ directory
3027
        if log directory specified, otherwise show interactively.
3028

3029
        This does essentially::
3030

3031
            from ultranest.plot import runplot
3032
            runplot(results=results)
3033

3034
        """
3035
        from .plot import runplot
1✔
3036
        import matplotlib.pyplot as plt
1✔
3037
        if self.log:
1!
3038
            self.logger.debug('Making run plot ... ')
1✔
3039
        # get dynesty-compatible sequences
3040
        runplot(results=self.run_sequence, logplot=True)
1✔
3041
        if self.log_to_disk:
1✔
3042
            plt.savefig(os.path.join(self.logs['plots'], 'run.pdf'), bbox_inches='tight')
1✔
3043
            plt.close()
1✔
3044
            self.logger.debug('Making run plot ... done')
1✔
3045

3046

3047
def read_file(log_dir, x_dim, num_bootstraps=20, random=True, verbose=False, check_insertion_order=True):
1✔
3048
    """
3049
    Read the output HDF5 file of UltraNest.
3050

3051
    Parameters
3052
    ----------
3053
    log_dir: str
3054
        Folder containing results
3055
    x_dim: int
3056
        number of dimensions
3057
    num_bootstraps: int
3058
        number of bootstraps to use for estimating logZ.
3059
    random: bool
3060
        use randomization for volume estimation.
3061
    verbose: bool
3062
        show progress
3063
    check_insertion_order: bool
3064
        whether to perform MWW insertion order test for assessing convergence
3065

3066
    Returns
3067
    ----------
3068
    sequence: dict
3069
        contains arrays storing for each iteration estimates of:
3070

3071
            * logz: log evidence estimate
3072
            * logzerr: log evidence uncertainty estimate
3073
            * logvol: log volume estimate
3074
            * samples_n: number of live points
3075
            * logwt: log weight
3076
            * logl: log likelihood
3077

3078
    final: dict
3079
        same as ReactiveNestedSampler.results and
3080
        ReactiveNestedSampler.run return values
3081

3082
    """
3083
    import h5py
1✔
3084
    filepath = os.path.join(log_dir, 'results', 'points.hdf5')
1✔
3085
    fileobj = h5py.File(filepath, 'r')
1✔
3086
    _, ncols = fileobj['points'].shape
1✔
3087
    num_params = ncols - 3 - x_dim
1✔
3088

3089
    points = fileobj['points'][:]
1✔
3090
    fileobj.close()
1✔
3091
    del fileobj
1✔
3092
    stack = list(enumerate(points))
1✔
3093

3094
    pointpile = PointPile(x_dim, num_params)
1✔
3095

3096
    def pop(Lmin):
1✔
3097
        """Find matching sample from points file."""
3098
        # look forward to see if there is an exact match
3099
        # if we do not use the exact matches
3100
        #   this causes a shift in the loglikelihoods
3101
        for i, (idx, next_row) in enumerate(stack):
1✔
3102
            row_Lmin = next_row[0]
1✔
3103
            L = next_row[1]
1✔
3104
            if row_Lmin <= Lmin and L > Lmin:
1✔
3105
                idx, row = stack.pop(i)
1✔
3106
                return idx, row
1✔
3107
        return None, None
1✔
3108

3109
    roots = []
1✔
3110
    while True:
3111
        _, row = pop(-np.inf)
1✔
3112
        if row is None:
1✔
3113
            break
1✔
3114
        logl = row[1]
1✔
3115
        u = row[3:3 + x_dim]
1✔
3116
        v = row[3 + x_dim:3 + x_dim + num_params]
1✔
3117
        roots.append(pointpile.make_node(logl, u, v))
1✔
3118

3119
    root = TreeNode(id=-1, value=-np.inf, children=roots)
1✔
3120

3121
    def onNode(node, main_iterator):
1✔
3122
        """Insert (single) child of node if available."""
3123
        while True:
3124
            _, row = pop(node.value)
1✔
3125
            if row is None:
1✔
3126
                break
1✔
3127
            if row is not None:
1!
3128
                logl = row[1]
1✔
3129
                u = row[3:3 + x_dim]
1✔
3130
                v = row[3 + x_dim:3 + x_dim + num_params]
1✔
3131
                child = pointpile.make_node(logl, u, v)
1✔
3132
                assert logl > node.value, (logl, node.value)
1✔
3133
                main_iterator.Lmax = max(main_iterator.Lmax, logl)
1✔
3134
                node.children.append(child)
1✔
3135

3136
    return logz_sequence(root, pointpile, nbootstraps=num_bootstraps,
1✔
3137
                         random=random, onNode=onNode, verbose=verbose,
3138
                         check_insertion_order=check_insertion_order)
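

# Illustrative usage sketch (not part of the module; the log directory and
# x_dim are assumptions):
#
#     sequence, final = read_file('myrun', x_dim=3)
#     print('logZ = %(logz).2f +- %(logzerr).2f' % final)
#     print('iterations recorded:', len(sequence['logz']))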