• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / UltraNest / 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC coverage: 74.53% (+0.3%) from 74.242%
9f2dd4f6-0775-47e9-b700-af647027ebfa

push

circleci

web-flow
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler

vectorised slice sampler of fixed batch size

1329 of 2026 branches covered (65.6%)

Branch coverage included in aggregate %.

79 of 80 new or added lines in 1 file covered. (98.75%)

1 existing line in 1 file now uncovered.

4026 of 5159 relevant lines covered (78.04%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.75
/ultranest/popstepsampler.py
1
"""
2
Vectorized step samplers
3
------------------------
4

5
Likelihood based on GPUs (model emulators based on neural networks,
6
or JAX implementations) can evaluate hundreds of points as efficiently
7
as one point. The implementations in this module leverage this power,
8
by providing random walks of populations of walkers.
9
"""
10

11
import numpy as np
1✔
12
from ultranest.utils import submasks
1✔
13
from ultranest.stepfuncs import evolve, step_back, update_vectorised_slice_sampler
1✔
14
from ultranest.stepfuncs import generate_cube_oriented_direction, generate_cube_oriented_direction_scaled
1✔
15
from ultranest.stepfuncs import generate_random_direction, generate_region_oriented_direction, generate_region_random_direction
1✔
16
from ultranest.stepfuncs import generate_differential_direction, generate_mixture_random_direction
1✔
17
import scipy.stats
1✔
18

19

20
def unitcube_line_intersection(ray_origin, ray_direction):
    r"""Compute intersection of a line (ray) and a unit box (0:1 in all axes).

    Based on
    http://www.iquilezles.org/www/articles/intersectors/intersectors.htm

    Parameters
    -----------
    ray_origin: array of vectors
        starting points of the lines, one per row (last axis is the
        coordinate axis; a single vector is also accepted)
    ray_direction: array of vectors
        line direction vectors, broadcast-compatible with ``ray_origin``.
        Every direction vector must be non-zero.

    Returns
    --------
    tleft: array
        negative intersection point distance from ray\_origin in units in ray\_direction
    tright: array
        positive intersection point distance from ray\_origin in units in ray\_direction

    """
    # make sure ray starts inside the box
    assert (ray_origin >= 0).all(), ray_origin
    assert (ray_origin <= 1).all(), ray_origin
    # check each direction vector individually: a total norm over the whole
    # stack would let a single zero-length direction slip through
    assert ((ray_direction**2).sum(axis=-1)**0.5 > 1e-200).all(), ray_direction

    # step size
    with np.errstate(divide='ignore', invalid='ignore'):
        # axes with a zero direction component produce inf/nan here;
        # they are skipped by nanmax/nanmin below
        m = 1. / ray_direction
        n = m * (ray_origin - 0.5)
        k = np.abs(m) * 0.5
        # line coordinates of the intersections with the two faces of each axis
        t1 = -n - k
        t2 = -n + k
        # the closest face on each side bounds the chord inside the box
        return np.nanmax(t1, axis=-1), np.nanmin(t2, axis=-1)
56

57
def diagnose_move_distances(region, ustart, ufinal):
    """Measure how far points travelled, relative to the MLFriends radius.

    Both point sets are mapped into the region's whitened space (t-space)
    with ``region.transformLayer.transform``. The Euclidean distance of each
    final point from its starting point is then compared to the
    bootstrapped MLFriends radius, ``region.maxradiussq**0.5``.

    Parameters
    ----------
    region: MLFriends
        built region, providing ``transformLayer`` and ``maxradiussq``
    ustart: array
        starting positions, one point per row
    ufinal: array
        final positions, same shape as ``ustart``

    Returns
    -------
    far_enough: array of bool
        per row, whether the squared whitened distance exceeds ``maxradiussq``
    distances: list
        two entries: the per-row whitened move distances, and the
        reference distance (MLFriends radius)
    """
    assert ustart.shape == ufinal.shape, (ustart.shape, ufinal.shape)
    to_whitened = region.transformLayer.transform
    delta = to_whitened(ustart) - to_whitened(ufinal)
    sq_move = np.sum(delta**2, axis=1)
    sq_reference = region.maxradiussq
    return sq_move > sq_reference, [sq_move**0.5, sq_reference**0.5]
88

89
class GenericPopulationSampler:
    """Shared diagnostics for the vectorized population step samplers.

    Subclasses provide ``self.logstat``, a list of per-batch statistics rows,
    and ``self.logstat_labels``, the column names of those rows.
    The methods here read the rows as follows:

    * the first column is used as a weight (an acceptance rate),
    * the second-to-last column is the "far enough" fraction,
    * the last column is the mean relative jump distance.

    ``get_info_dict`` additionally reports columns 1 and 2 as the mean scale
    and the mean number of steps.
    """

    def plot(self, filename):
        """Plot sampler statistics.

        Parameters
        -----------
        filename: str
            Stores plot into ``filename`` and data into
            ``filename + ".txt.gz"``.
        """
        if len(self.logstat) == 0:
            return

        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 1 + 3 * len(self.logstat_labels)))
        for i, label in enumerate(self.logstat_labels):
            part = [entry[i] for entry in self.logstat]
            plt.subplot(len(self.logstat_labels), 1, 1 + i)
            plt.ylabel(label)
            plt.plot(part)
            # overlay a running average over blocks of 20 entries
            x = []
            y = []
            for j in range(0, len(part), 20):
                x.append(j)
                y.append(np.mean(part[j:j + 20]))
            plt.plot(x, y)
            if np.min(part) > 0:
                plt.yscale('log')
        plt.savefig(filename, bbox_inches='tight')
        np.savetxt(filename + '.txt.gz', self.logstat,
                   header=','.join(self.logstat_labels), delimiter=',')
        plt.close()

    @property
    def mean_jump_distance(self):
        """Geometric mean jump distance.

        The last column of each ``logstat`` row is averaged in log-space,
        weighted by the first column. NaN if nothing was recorded.
        """
        if len(self.logstat) == 0:
            return np.nan
        return np.exp(np.average(
            np.log([entry[-1] + 1e-10 for entry in self.logstat]),
            weights=([entry[0] for entry in self.logstat])
        ))

    @property
    def far_enough_fraction(self):
        """Fraction of jumps exceeding reference distance.

        Average of the second-to-last ``logstat`` column, weighted by the
        first column. NaN if nothing was recorded.
        """
        if len(self.logstat) == 0:
            return np.nan
        return np.average(
            [entry[-2] for entry in self.logstat],
            weights=([entry[0] for entry in self.logstat])
        )

    def get_info_dict(self):
        """Summarise the recorded statistics in a dictionary.

        Returns
        -------
        info: dict
            ``num_logs``, ``rejection_rate``, ``mean_scale``, ``mean_nsteps``,
            ``mean_distance``, ``frac_far_enough``, and ``last_logstat`` (the
            most recent ``logstat`` row keyed by ``logstat_labels``). Values
            are NaN when nothing has been recorded yet.
        """
        have_logs = len(self.logstat) > 0
        return dict(
            num_logs=len(self.logstat),
            rejection_rate=1 - np.nanmean([entry[0] for entry in self.logstat]) if have_logs else np.nan,
            mean_scale=np.nanmean([entry[1] for entry in self.logstat]) if have_logs else np.nan,
            mean_nsteps=np.nanmean([entry[2] for entry in self.logstat]) if have_logs else np.nan,
            mean_distance=self.mean_jump_distance,
            frac_far_enough=self.far_enough_fraction,
            # a single recorded row is already a valid "last" row
            last_logstat=dict(zip(
                self.logstat_labels,
                self.logstat[-1] if have_logs else [np.nan] * len(self.logstat_labels))),
        )

    def print_diagnostic(self):
        """Print diagnostic of step sampler performance."""
        if len(self.logstat) == 0:
            print("diagnostic unavailable, no recorded steps found")
            return
        frac_farenough = self.far_enough_fraction
        average_distance = self.mean_jump_distance
        if frac_farenough < 0.5:
            advice = ': very fishy. Double nsteps and see if fraction and lnZ change)'
        elif frac_farenough < 0.66:
            advice = ': fishy. Double nsteps and see if fraction and lnZ change)'
        else:
            advice = ' (should be >50%)'
        print('step sampler diagnostic: jump distance %.2f (should be >1), far enough fraction: %.2f%% %s' % (
            average_distance, frac_farenough * 100, advice))

    def plot_jump_diagnostic_histogram(self, filename, **kwargs):
        """Plot jump diagnostic histogram.

        The histogram is of log10 of the recorded mean relative jump distances.
        A vertical line marks the weighted geometric mean, on the same log10 axis.

        Parameters
        ----------
        filename: str
            where to store the plot
        **kwargs:
            passed on to ``plt.hist``
        """
        if len(self.logstat) == 0:
            return
        import matplotlib.pyplot as plt
        plt.hist(np.log10([entry[-1] for entry in self.logstat]), **kwargs)
        ylo, yhi = plt.ylim()
        # the x axis is log10-scaled, so the marker must be too
        plt.vlines(np.log10(self.mean_jump_distance), ylo, yhi)
        plt.ylim(ylo, yhi)
        plt.xlabel('log(relative step distance)')
        plt.ylabel('Frequency')
        plt.savefig(filename, bbox_inches='tight')
        plt.close()
183

184

185
class PopulationRandomWalkSampler(GenericPopulationSampler):
    """Vectorized Gaussian Random Walk sampler."""

    def __init__(
        self, popsize, nsteps, generate_direction, scale,
        scale_adapt_factor=0.9, scale_min=1e-20, scale_max=20, log=False, logfile=None
    ):
        """Initialise.

        Parameters
        ----------
        popsize: int
            number of walkers to maintain.
            this should be fairly large (~100), if too large you probably get memory issues
            Also, some results have to be discarded as the likelihood threshold increases.
            Observe the nested sampling efficiency.
        nsteps: int
            number of steps to take until the found point is accepted as independent.
            To find the right value, see :py:class:`ultranest.calibrator.ReactiveNestedCalibrator`
        generate_direction: function
            Function that gives proposal kernel shape, one of:
            :py:func:`ultranest.popstepsampler.generate_cube_oriented_direction`
            :py:func:`ultranest.popstepsampler.generate_cube_oriented_direction_scaled`
            :py:func:`ultranest.popstepsampler.generate_random_direction`
            :py:func:`ultranest.popstepsampler.generate_region_oriented_direction`
            :py:func:`ultranest.popstepsampler.generate_region_random_direction`
        scale: float
            initial guess for the proposal scaling factor
        scale_adapt_factor: float
            if 1, no adapting is done.
            if <1, the scale is decreased (multiplied by *scale_adapt_factor*)
            if the acceptance rate is below 23.4%,
            or increased (divided by *scale_adapt_factor*) if it is above.
        scale_min: float
            lowest value allowed for scale, do not adapt down further
        scale_max: float
            highest value allowed for scale, do not adapt up further
        log: bool
            stored as ``self.log``; not otherwise used by this class
        logfile: file
            where to print the current scaling factor and acceptance rate

        """
        self.nsteps = nsteps
        self.nrejects = 0
        self.scale = scale
        self.ncalls = 0
        assert scale_adapt_factor <= 1
        self.scale_adapt_factor = scale_adapt_factor
        self.scale_min = scale_min
        self.scale_max = scale_max

        self.log = log
        self.logfile = logfile
        self.logstat = []
        # one label per column of the rows appended to self.logstat in __next__
        self.logstat_labels = ['accept_rate', 'efficiency', 'scale', 'nsteps', 'far_enough', 'mean_rel_jump']
        self.prepared_samples = []

        self.popsize = popsize
        self.generate_direction = generate_direction

    def __str__(self):
        """Return string representation."""
        return 'PopulationRandomWalkSampler(popsize=%d, nsteps=%d, generate_direction=%s, scale=%.g)' % (
            self.popsize, self.nsteps, self.generate_direction, self.scale)

    def region_changed(self, Ls, region):
        """Act upon region changed. Currently unused."""
        pass

    def __next__(
        self, region, Lmin, us, Ls, transform, loglike, ndraw=10,
        plot=False, tregion=None, log=False
    ):
        """Sample a new live point.

        When no prepared samples are left, ``popsize`` walkers are started from
        randomly chosen live points. They are advanced ``nsteps`` times in
        lock-step, and the final positions are buffered. Each call returns one
        buffered sample.

        Parameters
        ----------
        region: MLFriends object
            Region
        Lmin: float
            current log-likelihood threshold
        us: np.array((nlive, ndim))
            live points
        Ls: np.array(nlive)
            loglikelihoods live points
        transform: function
            prior transform function
        loglike: function
            loglikelihood function
        ndraw: int
            not used
        plot: bool
            not used
        tregion: bool
            not used
        log: bool
            not used

        Returns
        -------
        u: np.array(ndim)
            new point coordinates
        p: np.array(nparams)
            new point transformed coordinates
        L: float
            new point likelihood
        nc: int
            number of likelihood evaluations performed in this call
            (``nsteps * popsize`` when the buffer was refilled, otherwise 0)

        """
        nlive, _ = us.shape

        # fill if empty:
        if len(self.prepared_samples) == 0:
            # choose live points
            ilive = np.random.randint(0, nlive, size=self.popsize)
            allu = us[ilive,:]
            allp = None
            allL = Ls[ilive]
            nc = self.nsteps * self.popsize
            nrejects_before = self.nrejects
            # number of rejections expected at the target acceptance rate of 23.4%
            nrejects_expected = nrejects_before + self.nsteps * self.popsize * (1 - 0.234)

            for _ in range(self.nsteps):
                # perturb walker population
                v = self.generate_direction(allu, region, self.scale)
                # compute intersection of u + t * v with unit cube
                tleft, tright = unitcube_line_intersection(allu, v)
                proposed_t = scipy.stats.truncnorm.rvs(tleft, tright, loc=0, scale=1).reshape((-1, 1))

                proposed_u = allu + v * proposed_t
                mask_outside = ~np.logical_and(proposed_u > 0, proposed_u < 1).all(axis=1)
                assert not mask_outside.any(), proposed_u[mask_outside, :]

                proposed_p = transform(proposed_u)
                # accept if likelihood threshold exceeded
                proposed_L = loglike(proposed_p)
                mask_accept = proposed_L > Lmin
                self.nrejects += (~mask_accept).sum()
                allu[mask_accept,:] = proposed_u[mask_accept,:]
                if allp is None:
                    # allocate with the shape of the transformed points; NaN marks "never moved"
                    allp = proposed_p * np.nan
                allp[mask_accept,:] = proposed_p[mask_accept,:]
                allL[mask_accept] = proposed_L[mask_accept]
            assert np.isfinite(allp).all(), 'some walkers never moved! Double nsteps of PopulationRandomWalkSampler.'
            # diagnostics use the walkers accepted in the final step
            far_enough, (move_distance, reference_distance) = diagnose_move_distances(region, us[ilive[mask_accept],:], allu[mask_accept,:])
            self.prepared_samples = list(zip(allu, allp, allL))

            self.logstat.append([
                mask_accept.mean(),
                # fraction of all proposals in this batch that were accepted
                1 - (self.nrejects - nrejects_before) / (self.nsteps * self.popsize),
                self.scale,
                self.nsteps,
                np.mean(far_enough),
                np.exp(np.mean(np.log(move_distance / reference_distance + 1e-10)))
            ])
            if self.logfile:
                # %-formatting needs a tuple; one field per logstat column
                self.logfile.write("rescale\t%.4f\t%.4f\t%g\t%g\t%.4f\t%g\n" % tuple(self.logstat[-1]))

            # adapt slightly
            if self.nrejects > nrejects_expected and self.scale > self.scale_min:
                # lots of rejects, decrease scale
                self.scale *= self.scale_adapt_factor
            elif self.nrejects < nrejects_expected and self.scale < self.scale_max:
                # few rejects, increase scale
                self.scale /= self.scale_adapt_factor
        else:
            nc = 0

        u, p, L = self.prepared_samples.pop(0)
        return u, p, L, nc
352

353

354
class PopulationSliceSampler(GenericPopulationSampler):
    """Vectorized slice/HARM sampler.

    Can revert until all previous steps have likelihoods allL above Lmin.
    Updates currentt, generation and allL, in-place.
    """

    def __init__(
        self, popsize, nsteps, generate_direction, scale=1.0,
        scale_adapt_factor=0.9, log=False, logfile=None
    ):
        """Initialise.

        Parameters
        ----------
        popsize: int
            number of walkers to maintain
        nsteps: int
            number of steps to take until the found point is accepted as independent.
            To find the right value, see :py:class:`ultranest.calibrator.ReactiveNestedCalibrator`
        generate_direction: function `(u, region, scale) -> v`
            function such as `generate_unit_directions`, which
            generates a random slice direction.
        scale: float
            initial guess scale for the length of the slice
        scale_adapt_factor: float
            smoothing factor for updating scale.
            if near 1, scale is barely updating, if near 0,
            the last slice length is used as a initial guess for the next.
        log: bool
            if true, print bookkeeping messages while sampling
        logfile: file
            if given, rescaling statistics and region updates are written to it

        """
        self.nsteps = nsteps
        self.nrejects = 0
        self.scale = scale
        self.scale_adapt_factor = scale_adapt_factor
        # per-walker state; replaced by arrays in _setup() on first use
        self.allu = []
        self.allL = []
        self.currentt = []
        self.currentv = []
        self.currentp = []
        self.generation = []
        self.current_left = []
        self.current_right = []
        self.searching_left = []
        self.searching_right = []
        self.ringindex = 0

        self.log = log
        self.logfile = logfile
        self.logstat = []
        # one label per column of the rows appended to self.logstat in advance()
        self.logstat_labels = ['accept_rate', 'scale', 'nsteps', 'far_enough', 'mean_rel_jump']

        self.popsize = popsize
        self.generate_direction = generate_direction

    def __str__(self):
        """Return string representation."""
        return 'PopulationSliceSampler(popsize=%d, nsteps=%d, generate_direction=%s, scale=%.g)' % (
            self.popsize, self.nsteps, self.generate_direction, self.scale)

    def region_changed(self, Ls, region):
        """Act upon region changed. Currently only logs, if a logfile is set."""
        # self.scale = region.us.std(axis=1).mean()
        if self.logfile:
            self.logfile.write("region-update\t%g\t%g\n" % (self.scale, region.us.std(axis=1).mean()))

    def _setup(self, ndim):
        """Allocate arrays."""
        # allu/allL keep the whole chain of each walker (nsteps + 1 positions)
        self.allu = np.zeros((self.popsize, self.nsteps + 1, ndim)) + np.nan
        self.allL = np.zeros((self.popsize, self.nsteps + 1)) + np.nan
        # NaN currentt marks a walker whose slice bracket is not set up
        self.currentt = np.zeros(self.popsize) + np.nan
        self.currentv = np.zeros((self.popsize, ndim)) + np.nan
        # generation -1 marks a walker that needs a starting point
        self.generation = np.zeros(self.popsize, dtype=int) - 1
        self.current_left = np.zeros(self.popsize)
        self.current_right = np.zeros(self.popsize)
        self.searching_left = np.zeros(self.popsize, dtype=bool)
        self.searching_right = np.zeros(self.popsize, dtype=bool)

    def setup_start(self, us, Ls, starting):
        """Initialize walker starting points.

        For iteration zero, randomly selects a live point as starting point.

        Parameters
        ----------
        us: np.array((nlive, ndim))
            live points
        Ls: np.array(nlive)
            loglikelihoods live points
        starting: np.array(nwalkers, dtype=bool)
            which walkers to initialize.

        """
        if self.log:
            print("setting up:", starting)
        nlive = len(us)
        i = np.random.randint(nlive, size=starting.sum())

        if not starting.all():
            while starting[self.ringindex]:
                # if the one we are waiting for is being restarted,
                # we may as well pick the next one to wait for
                # because every other one is started from a random point
                # as well
                self.shift()

        self.allu[starting,0] = us[i]
        self.allL[starting,0] = Ls[i]
        self.generation[starting] = 0

    @property
    def status(self):
        """Return compact string representation of the current status."""
        s1 = ('G:' + ''.join(['%d' % g if g >= 0 else '_' for g in self.generation]))
        s2 = ('S:' + ''.join([
            'S' if not np.isfinite(self.currentt[i]) else 'L' if self.searching_left[i] else 'R' if self.searching_right[i] else 'B'
            for i in range(self.popsize)]))
        return s1 + '  ' + s2

    def setup_brackets(self, mask_starting, region):
        """Pick starting direction and range for slice.

        Parameters
        ----------
        region: MLFriends object
            Region
        mask_starting: np.array(nwalkers, dtype=bool)
            which walkers to set up.

        """
        if self.log:
            print("starting brackets:", mask_starting)
        i_starting, = np.where(mask_starting)
        self.current_left[i_starting] = -self.scale
        self.current_right[i_starting] = self.scale
        self.searching_left[i_starting] = True
        self.searching_right[i_starting] = True
        self.currentt[i_starting] = 0
        # choose direction for new slice
        # NOTE(review): scale is not passed; presumably generate_direction
        # has a default for it -- confirm.
        self.currentv[i_starting,:] = self.generate_direction(
            self.allu[i_starting, self.generation[i_starting]],
            region)

    def _setup_currentp(self, nparams):
        """Allocate the array of transformed coordinates per walker."""
        if self.log:
            print("setting currentp")
        self.currentp = np.zeros((self.popsize, nparams)) + np.nan

    def advance(self, transform, loglike, Lmin, region):
        """Advance the walker population.

        Parameters
        ----------
        transform: function
            prior transform function
        loglike: function
            loglikelihood function
        Lmin: float
            current log-likelihood threshold
        region: MLFriends object
            Region, used for the jump-distance diagnostics

        Returns
        -------
        nc: int
            number of likelihood calls, as reported by ``evolve``

        """
        movable = self.generation < self.nsteps
        all_movable = movable.all()
        # print("moving ", movable.sum(), self.popsize)
        if all_movable:
            i = np.arange(self.popsize)
            args = [
                self.allu[i, self.generation],
                self.allL[i, self.generation],
                # pass values directly
                self.currentt,
                self.currentv,
                self.current_left,
                self.current_right,
                self.searching_left,
                self.searching_right
            ]
            del i
        else:
            args = [
                self.allu[movable, self.generation[movable]],
                self.allL[movable, self.generation[movable]],
                # this makes copies
                self.currentt[movable],
                self.currentv[movable],
                self.current_left[movable],
                self.current_right[movable],
                self.searching_left[movable],
                self.searching_right[movable]
            ]
        if self.log:
            print("evolve will advance:", movable)

        uorig = args[0].copy()
        (
            (
                currentt, currentv,
                current_left, current_right, searching_left, searching_right
            ),
            (success, unew, pnew, Lnew),
            nc
        ) = evolve(transform, loglike, Lmin, *args)

        if success.any():
            far_enough, (move_distance, reference_distance) = diagnose_move_distances(region, uorig[success,:], unew)
            self.logstat.append([
                success.mean(),
                self.scale,
                self.nsteps,
                np.mean(far_enough) if len(far_enough) > 0 else 0,
                np.exp(np.mean(np.log(move_distance / reference_distance + 1e-10))) if len(far_enough) > 0 else 0
            ])
            if self.logfile:
                # %-formatting needs a tuple; one field per logstat column
                self.logfile.write("rescale\t%.4f\t%.4f\t%g\t%.4f\t%g\n" % tuple(self.logstat[-1]))

        if self.log:
            print("movable", movable.shape, movable.sum(), success.shape)
        moved = submasks(movable, success)
        if self.log:
            print("evolve moved:", moved)
        self.generation[moved] += 1
        if len(pnew) > 0:
            if len(self.currentp) == 0:
                self._setup_currentp(nparams=pnew.shape[1])

            if self.log:
                print("currentp", self.currentp[moved,:].shape, pnew.shape)
            self.currentp[moved,:] = pnew

        # update with what we learned
        # print(currentu.shape, currentL.shape, success.shape, self.generation[movable])
        self.allu[moved, self.generation[moved]] = unew
        self.allL[moved, self.generation[moved]] = Lnew
        if all_movable:
            # in this case, the values were directly overwritten
            pass
        else:
            # copies were passed to evolve; write the results back
            self.currentt[movable] = currentt
            self.currentv[movable] = currentv
            self.current_left[movable] = current_left
            self.current_right[movable] = current_right
            self.searching_left[movable] = searching_left
            self.searching_right[movable] = searching_right
        return nc

    def shift(self):
        """Update walker from which to pick next."""
        # this is a ring buffer
        # shift index forward, wrapping around
        # this is better than copying memory around when a element is removed
        self.ringindex = (self.ringindex + 1) % self.popsize

    def __next__(
        self, region, Lmin, us, Ls, transform, loglike, ndraw=10,
        plot=False, tregion=None, log=False
    ):
        """Sample a new live point.

        Parameters
        ----------
        region: MLFriends object
            Region
        Lmin: float
            current log-likelihood threshold
        us: np.array((nlive, ndim))
            live points
        Ls: np.array(nlive)
            loglikelihoods live points
        transform: function
            prior transform function
        loglike: function
            loglikelihood function
        ndraw: int
            not used
        plot: bool
            not used
        tregion: bool
            not used
        log: bool
            not used

        Returns
        -------
        u: np.array(ndim) or None
            new point coordinates (None if not yet available)
        p: np.array(nparams) or None
            new point transformed coordinates (None if not yet available)
        L: float or None
            new point likelihood (None if not yet available)
        nc: int
            number of likelihood calls in this call, as reported by ``evolve``

        """
        nlive, ndim = us.shape
        # initialize
        if len(self.allu) == 0:
            self._setup(ndim)

        step_back(Lmin, self.allL, self.generation, self.currentt)

        starting = self.generation < 0
        if starting.any():
            self.setup_start(us[Ls > Lmin], Ls[Ls > Lmin], starting)
        assert (self.generation >= 0).all(), self.generation

        # find those where bracket is undefined:
        mask_starting = ~np.isfinite(self.currentt)
        if mask_starting.any():
            self.setup_brackets(mask_starting, region)

        if self.log:
            print(str(self), "(before)")
        nc = self.advance(transform, loglike, Lmin, region)
        if self.log:
            print(str(self), "(after)")

        # harvest top individual if possible
        if self.generation[self.ringindex] == self.nsteps:
            if self.log:
                print("have a candidate")
            u = self.allu[self.ringindex, self.nsteps, :].copy()
            p = self.currentp[self.ringindex, :].copy()
            L = self.allL[self.ringindex, self.nsteps].copy()
            assert np.isfinite(u).all(), u
            assert np.isfinite(p).all(), p
            # reset the harvested walker so it is restarted next time
            self.generation[self.ringindex] = -1
            self.currentt[self.ringindex] = np.nan
            self.allu[self.ringindex,:,:] = np.nan
            self.allL[self.ringindex,:] = np.nan

            # adjust guess length, smoothed by scale_adapt_factor
            newscale = (self.current_right[self.ringindex] - self.current_left[self.ringindex]) / 2
            self.scale = self.scale * self.scale_adapt_factor + (1 - self.scale_adapt_factor) * newscale

            self.shift()
            return u, p, L, nc
        else:
            return None, None, None, nc
689

690

691

692

693
def slice_limit_to_unitcube(tleft, tright):
    """Use the whole chord through the unit cube as the slice.

    The slice limits are exactly the intersections of the slice line with
    the unit cube boundaries. They are returned as copies, so modifying them
    does not alter the inputs.

    Parameters
    ----------
    tleft: array
        intersection of the unit cube with the slice in the negative direction
    tright: array
        intersection of the unit cube with the slice in the positive direction

    Returns
    -------
    (tleft_new, tright_new): tuple
        negative and positive slice limits
    """
    return tuple(limit.copy() for limit in (tleft, tright))
712

713

714
def slice_limit_to_scale(tleft, tright):
    """Clip the slice to at most one direction-vector length on each side.

    In units of the direction vector, the limits are clipped to the interval
    [-1, 1], unless the unit cube boundaries are closer. Presumably the caller
    scales the direction vector by ``scale``, so that this corresponds to an
    interval of width ``2*scale`` (TODO confirm against the caller).

    Parameters
    ----------
    tleft: array
        intersection of the unit cube with the slice in the negative direction
    tright: array
        intersection of the unit cube with the slice in the positive direction

    Returns
    -------
    (tleft_new, tright_new): tuple
        negative and positive slice limits
    """
    # fmax/fmin return the non-NaN operand, so a NaN limit becomes the bound
    lower = np.fmax(tleft, -1.0)
    upper = np.fmin(tright, 1.0)
    return lower, upper
736

737

738

739
class PopulationSimpleSliceSampler(GenericPopulationSampler):
    """Vectorized slice sampler without stepping-out procedure, for quick-look fits.

    Unlike :py:class:`PopulationSliceSampler`, :py:class:`PopulationSimpleSliceSampler`
    always calls the likelihood with the same number of points (`popsize`).

    Slices are defined by the `generate_direction` function on an interval
    around the current point. The initial interval is set by the `slice_limit`
    function; there is no stepping-out procedure as in
    :py:class:`PopulationSliceSampler`. Slices are then shrunk towards the
    current point until a point with a likelihood above the threshold is found,
    or until `max_it` iterations have been performed.

    With the default `slice_limit=slice_limit_to_unitcube`, the slice extends to
    the unit cube boundaries. To improve the efficiency of the sampler,
    `slice_limit=slice_limit_to_scale` reduces the slice to an interval of size
    `2*scale` centred on the point. `scale` can be adapted with the
    `scale_adapt_factor` parameter, based on the median width of the final
    slice intervals among all the chains: if it is at least
    `1/adapt_slice_scale_target`, the scale is divided by `scale_adapt_factor`,
    otherwise it is multiplied by it. The `scale` parameter can also be jittered
    by a user-supplied function `scale_jitter_func` to counterbalance the effect
    of a strong adaptation.

    When the slice is restricted by `scale` (e.g. with `slice_limit_to_scale`),
    detailed balance is not guaranteed, so this sampler should be used with caution.

    Multiple (`popsize`) slice sampling chains are run independently and in parallel.
    Proposed points are read as if they were selected one after the other.
    For a point to update a slice, it must still lie in the part of the slice
    that remains to be searched after the earlier points have been read. In that
    case the slice is updated as normal; otherwise the point is discarded.
    """

    def __init__(
        self, popsize, nsteps, generate_direction,
        scale_adapt_factor=1.0, adapt_slice_scale_target=2.0,
        scale=1.0, scale_jitter_func=None, slice_limit=slice_limit_to_unitcube,
        max_it=100, shrink_factor=1.0):
        """Initialise.

        Parameters
        ----------
        popsize: int
            number of walkers to maintain.
        nsteps: int
            number of steps to take until the found point is accepted as independent.
            To calibrate, try several runs with increasing nsteps (doubling).
            The ln(Z) should become stable at some value.
        generate_direction: function
            Function that gives proposal kernel shape, one of:
            :py:func:`ultranest.popstepsampler.generate_random_direction`
            :py:func:`ultranest.popstepsampler.generate_region_oriented_direction`
            :py:func:`ultranest.popstepsampler.generate_region_random_direction`
            :py:func:`ultranest.popstepsampler.generate_differential_direction`
            :py:func:`ultranest.popstepsampler.generate_mixture_random_direction`
            :py:func:`ultranest.popstepsampler.generate_cube_oriented_direction` -> no adaptation in that case
            :py:func:`ultranest.popstepsampler.generate_cube_oriented_direction_scaled` -> no adaptation in that case
        scale: float or None
            initial guess for the slice width. None is treated as 1.0.
        scale_jitter_func: function
            User supplied function to multiply the `scale` by a random factor. For example,
            :py:func:`lambda : scipy.stats.truncnorm.rvs(-0.5, 5., loc=0, scale=1)+1.`
            If None, no jitter is applied (factor 1).
        scale_adapt_factor: float
            adaptation of `scale`. If 1: no adaptation. After each batch, if the median
            final slice width (averaged over the `nsteps` steps) is at least
            `1/adapt_slice_scale_target`, the scale is divided by this factor,
            otherwise it is multiplied by it. Hence with a factor <1, the scale
            is increased when the final slices are long and decreased when they are short.
        adapt_slice_scale_target: float
            Sets the threshold `1/adapt_slice_scale_target` on the median final slice
            width used for the `scale` adaptation.
            Default: 2.0. Higher values are more conservative, lower values are faster.
        slice_limit: function
            Function setting the initial slice upper and lower bound. The default is
            :py:func:`slice_limit_to_unitcube`, which defines the slice limit as the
            intersection between the slice and the unit cube. An alternative when
            `scale` is used is :py:func:`slice_limit_to_scale`, which defines the slice
            limit as an interval of size `2*scale`. This function should either return
            a copy of the `tleft` and `tright` arguments or new arrays of the same shape.
        max_it: int
            maximum number of iterations to find a point on the slice. If the maximum
            number of iterations is reached, the current point is returned as the next one.
        shrink_factor: float
            For standard slice sampling shrinking, `shrink_factor=1`, the slice bound is
            updated to the last rejected point. Setting `shrink_factor>1` aggressively
            accelerates the shrinkage, by updating the new slice bound to `1/shrink_factor`
            of the distance between the current point and rejected point.
        """
        assert shrink_factor >= 1.0, "The shrink factor should be greater than 1.0 to be efficient"

        self.popsize = popsize
        self.nsteps = nsteps
        self.max_it = max_it
        self.nrejects = 0
        self.generate_direction = generate_direction
        self.scale_adapt_factor = scale_adapt_factor
        self.adapt_slice_scale_target = adapt_slice_scale_target
        self.shrink_factor = shrink_factor
        self.slice_limit = slice_limit

        # accept scale=None as an alias for the default width of 1.0
        self.scale = 1.0 if scale is None else float(scale)

        if scale_jitter_func is None:
            self.scale_jitter_func = lambda: 1.
        else:
            self.scale_jitter_func = scale_jitter_func

        self.ncalls = 0
        self.discarded = 0
        self.prepared_samples = []

        self.logstat = []
        # must list the values appended to self.logstat in __next__, in the same order
        self.logstat_labels = ['accept_rate', 'scale', 'nsteps', 'far_enough', 'mean_rel_jump']

    def __str__(self):
        """Return string representation."""
        return 'PopulationSimpleSliceSampler(popsize=%d, nsteps=%d, generate_direction=%s, scale=%.g)' % (
            self.popsize, self.nsteps, self.generate_direction, self.scale)

    def region_changed(self, Ls, region):
        """Act upon region changed. Currently unused."""
        pass

    def __next__(
        self, region, Lmin, us, Ls, transform, loglike, ndraw=10,
        plot=False, tregion=None, log=False, test=False
    ):
        """Sample a new live point.

        Parameters
        ----------
        region: MLFriends object
            Region
        Lmin: float
            current log-likelihood threshold
        us: np.array((nlive, ndim))
            live points
        Ls: np.array(nlive)
            loglikelihoods live points
        transform: function
            prior transform function
        loglike: function
            loglikelihood function
        ndraw: int
            not used
        plot: bool
            not used
        tregion: bool
            not used
        log: bool
            not used
        test: bool
            In case of test of the reversibility of the sampler, the points drawn
            from the live points needs to be deterministic. This parameters is
            ensuring that.

        Returns
        -------
        u: np.array(ndim) or None
            new point coordinates (None if not yet available)
        p: np.array(nparams) or None
            new point transformed coordinates (None if not yet available)
        L: float or None
            new point likelihood (None if not yet available)
        nc: int
            number of likelihood evaluations made during this call
            (0 when a previously prepared sample is returned).
        """
        nlive, ndim = us.shape

        # fill if empty:
        if len(self.prepared_samples) == 0:
            # choose live points
            ilive = np.random.randint(0, nlive, size=self.popsize)
            allu = np.array(us[ilive,:]) if not test else np.array(us)
            # Allocated from the first transform output, so that transforms
            # returning more parameters than ndim (derived parameters) fit.
            allp = None
            allL = np.array(Ls[ilive])
            nc = 0
            n_discarded = 0

            # sum over steps of the median final slice width, for scale adaptation
            interval_final = 0.

            for _ in range(self.nsteps):
                # Defining scale jitter
                factor_scale = self.scale_jitter_func()
                # Defining slice direction
                v = self.generate_direction(allu, region, scale=1.0) * self.scale * factor_scale

                # limits of the slice based on the unit cube boundaries
                tleft_unitcube, tright_unitcube = unitcube_line_intersection(allu, v)

                # Bounds for each point and for each likelihood call are identical initially.
                # Slice bounds for each likelihood call (worker):
                tleft_worker, tright_worker = self.slice_limit(tleft_unitcube, tright_unitcube)
                # Slice bounds for each point:
                tleft, tright = self.slice_limit(tleft_unitcube, tright_unitcube)
                # Index of the chain each worker is currently working on
                worker_running = np.arange(0, self.popsize, 1, dtype=int)
                # Status indicating if a point has already found its next position
                status = np.zeros(self.popsize, dtype=int)  # one for success, zero for running

                # Loop until each point has found its next position or max_it iterations are reached
                for _ in range(self.max_it):
                    # Sampling points on the slices
                    slice_position = np.random.uniform(size=(self.popsize,))
                    t = tleft_worker + (tright_worker - tleft_worker) * slice_position

                    points = allu[worker_running,:]
                    v_worker = v[worker_running,:]
                    proposed_u = points + t.reshape((-1,1)) * v_worker

                    proposed_p = transform(proposed_u)
                    if allp is None:
                        allp = np.zeros((self.popsize, np.shape(proposed_p)[1]))
                    proposed_L = loglike(proposed_p)
                    nc += self.popsize

                    # Updating the pool of points based on the newly sampled points
                    (tleft, tright, worker_running, status, allu, allL, allp,
                     n_discarded_it) = update_vectorised_slice_sampler(
                        t, tleft, tright, proposed_L, proposed_u, proposed_p,
                        worker_running, status, Lmin, self.shrink_factor,
                        allu, allL, allp, self.popsize)
                    n_discarded += n_discarded_it
                    # Update of the limits of the slices
                    tleft_worker = tleft[worker_running]
                    tright_worker = tright[worker_running]
                    if not np.any(status == 0):
                        break
                # Record of the final interval on theta for scale adaptation
                interval_final += np.median(tright - tleft)

            if allp is None:
                # no likelihood call was made (nsteps or max_it is 0)
                allp = np.zeros((self.popsize, ndim))

            interval_final = interval_final / self.nsteps

            self.discarded += n_discarded
            self.ncalls += nc

            # Each walker must have moved at least once: its row of allp must no
            # longer be the all-zero placeholder. (A single coordinate being 0,
            # e.g. for a fixed parameter, is legitimate.)
            assert np.any(allp != 0, axis=1).all(), 'some walkers never moved! Double nsteps of PopulationSimpleSliceSampler.'
            far_enough, (move_distance, reference_distance) = diagnose_move_distances(region, us[ilive,:], allu)
            self.prepared_samples = list(zip(allu, allp, allL))

            # order must match self.logstat_labels
            self.logstat.append([
                self.popsize / nc,
                self.scale,  # stays 1. with the default scale and scale_adapt_factor
                self.nsteps,
                np.mean(far_enough) if len(far_enough) > 0 else 0,
                np.exp(np.mean(np.log(move_distance / reference_distance + 1e-10))) if len(far_enough) > 0 else 0
            ])

            # Scale adaptation: grow the scale when the final slices are long
            # (relative to 1/adapt_slice_scale_target), shrink it otherwise.
            # There may be better things to do here, but it seems to work.
            if interval_final >= 1. / self.adapt_slice_scale_target:
                self.scale *= 1. / self.scale_adapt_factor
            else:
                self.scale *= self.scale_adapt_factor
        else:
            nc = 0

        u, p, L = self.prepared_samples.pop(0)
        return u, p, L, nc
1✔
1017

1018

1019
# Public API. Also re-exports the direction generators and slice-limit
# functions that users pass to the samplers' constructors.
__all__ = [
    "generate_cube_oriented_direction", "generate_cube_oriented_direction_scaled",
    "generate_random_direction", "generate_region_oriented_direction", "generate_region_random_direction",
    "generate_differential_direction", "generate_mixture_random_direction",
    "slice_limit_to_unitcube", "slice_limit_to_scale",
    "PopulationRandomWalkSampler", "PopulationSliceSampler", "PopulationSimpleSliceSampler"]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc