• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / UltraNest / 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC coverage: 74.53% (+0.3%) from 74.242%
9f2dd4f6-0775-47e9-b700-af647027ebfa

push

circleci

web-flow
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler

vectorised slice sampler of fixed batch size

1329 of 2026 branches covered (65.6%)

Branch coverage included in aggregate %.

79 of 80 new or added lines in 1 file covered. (98.75%)

1 existing line in 1 file now uncovered.

4026 of 5159 relevant lines covered (78.04%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

5.84
/ultranest/pathsampler.py
1
"""MCMC-like step sampling on a trajectory
2

3
These features are experimental.
4
"""
5

6
import numpy as np
1✔
7

8
import matplotlib.pyplot as plt
1✔
9

10
from ultranest.samplingpath import SamplingPath, ContourSamplingPath, extrapolate_ahead
1✔
11
from ultranest.stepsampler import StepSampler
1✔
12
from ultranest.stepsampler import generate_region_oriented_direction, generate_region_random_direction, generate_random_direction
1✔
13

14
from ultranest.flatnuts import ClockedStepSampler, ClockedBisectSampler, ClockedNUTSSampler
1✔
15
from ultranest.flatnuts import SingleJumper, DirectJumper, IntervalJumper
1✔
16

17

18
class SamplingPathSliceSampler(StepSampler):
    """Slice sampler, respecting the region, on the sampling path.

    This first builds up a complete trajectory, respecting reflections.
    Then, from the trajectory a new point is drawn with slice sampling.

    The trajectory is built by doubling the length to each side and
    checking if the point is still inside. If not, reflection is
    attempted with the gradient (either provided or region-based estimate).

    Points on the trajectory are addressed by integer indices, index 0
    being the starting point. The current slice is kept in
    ``self.interval`` as ``(left, right, mid)``, where ``mid`` is the
    index proposed most recently.
    NOTE(review): this class itself only calls
    ``ContourSamplingPath.extrapolate``; confirm that reflections happen there.
    """

    def __init__(self, nsteps):
        """Initialise sampler.

        Parameters
        -----------
        nsteps: int
            number of accepted steps until the sample is considered independent.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        # current slice (left, right, mid); None means "build a new slice"
        self.interval = None
        # trajectory through the current starting point
        self.path = None

    def generate_direction(self, ui, region, scale=1):
        """Choose new initial direction according to region.transformLayer axes."""
        return generate_region_oriented_direction(ui, region, tscale=1, scale=scale)

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Adjust proposal given that we have been *accepted* at a new point after *nc* calls."""
        if accepted:
            # start with a new interval next time
            self.interval = None

            self.last = unew, Lnew
            self.history.append((unew, Lnew))
        else:
            # keep the current interval; move() shrinks it past the
            # rejected index on the next call
            self.nrejects += 1
        self.logstat.append([accepted, self.scale])

    def adjust_outside_region(self):
        """Adjust proposal given that we have stepped out of region."""
        self.logstat.append([False, self.scale])

    @staticmethod
    def _inside_cube_and_region(x, region):
        """Return True if *x* is strictly inside the unit cube and inside *region*."""
        return bool(
            np.all(x > 0) and np.all(x < 1)
            and np.all(region.inside(x.reshape((1, -1)))))

    def _build_interval(self, ui, region):
        """Start a new path through *ui* and return its initial slice (left, right).

        Both ends are doubled until the extrapolated point leaves the unit
        cube or the region, or the path length reaches the unit cube
        diagonal. If the slice stays short, the step scale is reduced.
        """
        v = self.generate_direction(ui, region, scale=self.scale)
        self.path = ContourSamplingPath(
            SamplingPath(ui, v, 0.0), region)

        assert self._inside_cube_and_region(ui, region), ui

        # unit hypercube diagonal gives a reasonable maximum path length
        maxlength = len(ui)**0.5

        # expand each side until it is surely outside
        left = -1
        while abs(left * self.scale) < maxlength:
            xj, _ = self.path.extrapolate(left)
            if not self._inside_cube_and_region(xj, region):
                break
            left *= 2

        right = +1
        while abs(right * self.scale) < maxlength:
            xj, _ = self.path.extrapolate(right)
            if not self._inside_cube_and_region(xj, region):
                break
            right *= 2

        span = max(-left, right)
        if span < 5:
            # the slice hit the boundary almost immediately: steps too large
            self.scale /= 1.1

        assert self.scale > 1e-10, self.scale
        self.interval = (left, right, None)
        return left, right

    def move(self, ui, region, ndraw=1, plot=False):
        """Advance by slice sampling on the path.

        After an acceptance (``self.interval is None``) a new slice is built
        around *ui*. After a rejection the previous slice is shrunk so that
        it no longer contains the rejected index. A uniformly drawn index on
        the slice is then returned as a point of shape (1, d). Draws that fall
        outside the unit cube or the region shrink the slice further.
        """
        if self.interval is None:
            left, right = self._build_interval(ui, region)
        else:
            left, right, mid = self.interval
            # we rejected mid: exclude it (and everything beyond it)
            if mid < 0:
                left = mid + 1
            elif mid > 0:
                right = mid - 1

        # invariant: left <= 0 <= right, so the slice is never empty
        while True:
            mid = np.random.randint(left, right + 1)
            if mid == 0:
                _, xj, _, _ = self.path.points[0]
            else:
                xj, _ = self.path.extrapolate(mid)

            if self._inside_cube_and_region(xj, region):
                self.interval = (left, right, mid)
                return xj.reshape((1, -1))

            if mid < 0:
                left = mid + 1
            elif mid > 0:
                right = mid - 1
            else:
                # the slice origin itself is no longer acceptable (e.g. the
                # region changed); continue on the left side only, and fail
                # loudly instead of looping forever once nothing is left
                assert left < 0, ("no valid point left on the slice", xj)
                right = 0
            self.interval = (left, right, mid)
130

131

132
class SamplingPathStepSampler(StepSampler):
    """Step sampler on a sampling path.

    Points are addressed by integer indices along a trajectory
    (``ContourSamplingPath``) through the current point. Each step moves
    one index forward or backward. Indices where the walk left the unit
    cube/region, or where a reflection failed, are remembered as dead ends
    and the walk turns around at them.
    """

    def __init__(self, nresets, nsteps, scale=1.0, balance=0.01, nudge=1.1, log=False):
        """Initialise sampler.

        Parameters
        ------------
        nresets: int
            number of paths (directions) to walk before a sample is returned.
        nsteps: int
            number of steps taken along each path before a new direction
            is chosen.
        scale: float
            initial step size
        balance: float
            acceptance rate to target.
            If the acceptance rate is below it, the scale is decreased;
            otherwise the scale is increased.
        nudge: float
            factor for changing scale (must be >=1)
            nudge=1 implies no step size adaptation.
        log: bool
            whether to print diagnostic messages.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        self.path = None
        self.nresets = nresets
        # initial step scale in transformed space
        self.scale = scale
        # target acceptance fraction, compared against in adjust_scale()
        self.balance = balance
        # relative change in step scale
        self.nudge = nudge
        assert nudge >= 1
        self.log = log
        self.grad_function = None
        self.istep = 0
        self.iresets = 0
        self.start()
        self.terminate_path()
        self.logstat_labels = ['acceptance rate', 'reflection rate', 'scale', 'nstuck']

    def __str__(self):
        """Get string representation."""
        # NOTE(review): adjust_scale() treats ``balance`` as the target
        # acceptance rate, but this reports 1 - balance as "AR". The output
        # is kept unchanged; confirm which meaning is intended.
        return '%s(nsteps=%d, nresets=%d, AR=%d%%)' % (
            type(self).__name__, self.nsteps, self.nresets, (1 - self.balance) * 100)

    def start(self):
        """Start sampler, reset all counters."""
        if hasattr(self, 'naccepts') and self.nrejects + self.naccepts > 0:
            self.logstat.append([
                self.naccepts / (self.nrejects + self.naccepts),
                self.nreflects / (self.nreflects + self.nrejects + self.naccepts),
                self.scale, self.nstuck])
        self.nrejects = 0
        self.naccepts = 0
        self.nreflects = 0
        self.nstuck = 0
        self.istep = 0
        self.iresets = 0
        self.noutside_regions = 0
        self.last = None, None
        self.history = []

        self.direction = +1
        self.deadends = set()
        self.path = None

    def start_path(self, ui, region):
        """Start new trajectory path from *ui*, which must be accepted (L = self.last[1])."""
        v = self.generate_direction(ui, region, scale=self.scale)
        assert (v**2).sum() > 0, (v, self.scale)
        assert region.inside(ui.reshape((1, -1))).all(), ui
        self.path = ContourSamplingPath(SamplingPath(ui, v, 0.0), region)
        if self.grad_function is not None:
            self.path.gradient = self.grad_function

        # the starting point must also lie strictly inside the unit cube
        assert (ui > 0).all() and (ui < 1).all(), ui

        self.direction = +1
        self.lasti = 0
        # cache: index -> (accepted, u, L); index 0 is the starting point
        self.cache = {0: (True, ui, self.last[1])}
        self.deadends = set()
        if self.log:
            print()
            print("starting new direction", v, 'from', ui)

    def terminate_path(self):
        """Terminate current path, and reset path counting variable."""
        # check if we went anywhere:
        if -1 in self.deadends and +1 in self.deadends:
            self.nstuck += 1

        self.direction = +1
        self.deadends = set()
        self.path = None
        self.iresets += 1
        if self.log:
            print("reset %d" % self.iresets)

    def set_gradient(self, grad_function):
        """Set gradient function."""
        print("set gradient function to %s" % grad_function.__name__)

        def plot_gradient_wrapper(x, plot=False):
            """Call *grad_function* on *x*, optionally drawing the gradient."""
            v = grad_function(x)
            if plot:
                plt.plot(x[0], x[1], '+ ', color='k', ms=10)
                plt.plot([x[0], v[0] * 1e-2 + x[0]],
                         [x[1], v[1] * 1e-2 + x[1]], color='gray')
            return v
        self.grad_function = plot_gradient_wrapper

    def generate_direction(self, ui, region, scale):
        """Choose a random axis from region.transformLayer."""
        return generate_region_random_direction(ui, region, scale=scale)

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Adjust proposal given that we have been *accepted* at a new point after *nc* calls."""
        self.cache[self.nexti] = (accepted, unew, Lnew)
        if accepted:
            # start at new point next time
            self.lasti = self.nexti
            self.last = unew, Lnew
            self.history.append((unew, Lnew))
            self.naccepts += 1
        else:
            # continue on current point, do not update self.last
            self.nrejects += 1
            self.history.append((unew, Lnew))
            assert self.scale > 1e-10, (self.scale, self.istep, self.nrejects)

    def adjust_outside_region(self):
        """Adjust proposal given that we landed outside region."""
        self.noutside_regions += 1
        self.nrejects += 1

    def adjust_scale(self, maxlength):
        """Adjust scale towards the target acceptance rate ``balance``.

        NOTE: *maxlength* is currently not enforced; the scale is increased
        unconditionally when the acceptance rate is high enough.
        """
        assert len(self.history) > 1

        if self.naccepts < (self.nrejects + self.naccepts) * self.balance:
            if self.log:
                print("adjusting scale %f down: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale /= self.nudge
        else:
            if self.log:
                print("adjusting scale %f up: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale *= self.nudge
        assert self.scale > 1e-10, self.scale

    def movei(self, ui, region, ndraw=1, plot=False):
        """Make a move and return the proposed index."""
        if self.path is not None:
            if self.lasti - 1 in self.deadends and self.lasti + 1 in self.deadends:
                # stuck, cannot go anywhere. Stay.
                self.nexti = self.lasti
                return self.nexti

        if self.path is None:
            self.start_path(ui, region)

        assert not (self.lasti - 1 in self.deadends and self.lasti + 1 in self.deadends), \
            (self.deadends, self.lasti)
        # turn around at dead ends
        if self.lasti + self.direction in self.deadends:
            self.direction *= -1

        self.nexti = self.lasti + self.direction
        return self.nexti

    def move(self, ui, region, ndraw=1, plot=False):
        """Advance move."""
        u, v = self.get_point(self.movei(ui, region=region, ndraw=ndraw, plot=plot))
        return u.reshape((1, -1))

    def reflect(self, reflpoint, v, region, plot=False):
        """Reflect at *reflpoint* going in direction *v*. Return new direction."""
        normal = self.path.gradient(reflpoint, plot=plot)
        if normal is None:
            return -v
        return v - 2 * (normal * v).sum() * normal

    def get_point(self, inew):
        """Get (position, direction) of the point with index *inew*.

        Uses the stored path point if there is one, otherwise extrapolates
        along the path.
        """
        # path points are stored as (index, position, direction, loglikelihood)
        for i, u, v, _ in self.path.points:
            if i == inew:
                return u, v
        return self.path.extrapolate(inew)

    def _select_start(self, region, Lmin, us, Ls):
        """Return a starting point (ui, Li) satisfying *Lmin* and *region*.

        Tries the last accepted point first, then the most recent suitable
        point in the history, and otherwise draws a random live point
        (which resets all counters via start()).
        """
        ui, Li = self.last
        if Li is not None and not Li >= Lmin:
            if self.log:
                print("wandered out of L constraint; resetting", ui[0])
            ui, Li = None, None

        if Li is not None and not region.inside(ui.reshape((1,-1))):
            # region was updated and we are not inside anymore
            # so reset
            if self.log:
                print("region change; resetting")
            ui, Li = None, None

        if Li is None and self.history:
            # try to resume from a previous point above the current contour
            for uj, Lj in self.history[::-1]:
                if Lj >= Lmin and region.inside(uj.reshape((1,-1))):
                    ui, Li = uj, Lj
                    if self.log:
                        print("recovered using history", ui)
                    # continue from the recovered point: the old path and
                    # its cache refer to the abandoned point, so restart it
                    self.last = (ui, Li)
                    self.path = None
                    break

        if Li is None:
            # choose a new random starting point
            mask = region.inside(us)
            assert mask.any(), (
                "None of the live points satisfies the current region!",
                region.maxradiussq, region.u, region.unormed, us)
            i = np.random.randint(mask.sum())
            self.starti = i
            ui = us[mask,:][i]
            if self.log:
                print("starting at", ui)
            assert np.logical_and(ui > 0, ui < 1).all(), ui
            Li = Ls[mask][i]
            self.start()
            self.history.append((ui, Li))
            self.last = (ui, Li)
        return ui, Li

    def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=40, plot=False):
        """Get next point.

        Parameters
        ----------
        region: MLFriends
            region.
        Lmin: float
            loglikelihood threshold
        us: array of vectors
            current live points
        Ls: array of floats
            current live point likelihoods
        transform: function
            transform function
        loglike: function
            loglikelihood function
        ndraw: int
            number of draws to attempt simultaneously.
        plot: bool
            whether to produce debug plots.

        Returns
        -------
        (u, p, L, nc) once ``nresets`` paths have been walked, where *nc*
        is the number of likelihood calls made in this invocation;
        otherwise (None, None, None, nc).

        """
        ui, _ = self._select_start(region, Lmin, us, Ls)

        inew = self.movei(ui, region, ndraw=ndraw)
        if self.log:
            print("i: %d->%d (step %d)" % (self.lasti, inew, self.istep))

        _, uold, Lold = self.cache[self.lasti]
        if plot:
            plt.plot(uold[0], uold[1], 'd', color='brown', ms=4)

        # unless a move is accepted, the current point is reported
        uret, pret, Lret = uold, transform(uold), Lold

        nc = 0
        if inew != self.lasti:
            if inew not in self.cache:
                unew, _ = self.get_point(inew)
                if plot:
                    plt.plot(unew[0], unew[1], 'x', color='k', ms=4)
                accept = np.logical_and(unew > 0, unew < 1).all() and region.inside(unew.reshape((1, -1)))
                if accept:
                    if plot:
                        plt.plot(unew[0], unew[1], '+', color='orange', ms=4)
                    pnew = transform(unew)
                    # take the scalar, as in the reflection branch below
                    Lnew = loglike(pnew.reshape((1, -1)))[0]
                    nc = 1
                else:
                    Lnew = -np.inf
                    if self.log:
                        print("outside region: ", unew, "from", ui)
                    self.deadends.add(inew)
                    self.adjust_outside_region()
            else:
                _, unew, Lnew = self.cache[self.nexti]

            if self.log:
                print("   suggested point:", unew)
            pnew = transform(unew)

            if Lnew >= Lmin:
                if self.log:
                    print(" -> inside.")
                if plot:
                    plt.plot(unew[0], unew[1], 'o', color='g', ms=4)
                self.adjust_accept(True, unew, pnew, Lnew, nc)
                uret, pret, Lret = unew, pnew, Lnew
            else:
                if plot:
                    plt.plot(unew[0], unew[1], '+', color='k', ms=2, alpha=0.3)
                if self.log:
                    print(" -> outside.")
                jump_successful = False
                if inew not in self.cache and inew not in self.deadends:
                    # first time we try to go beyond
                    # try to reflect:
                    reflpoint, v = self.get_point(inew)
                    if self.log:
                        print("    trying to reflect at", reflpoint)
                    self.nreflects += 1

                    sign = -1 if inew < 0 else +1
                    vnew = self.reflect(reflpoint, v * sign, region=region) * sign

                    xk, vk = extrapolate_ahead(sign, reflpoint, vnew, contourpath=self.path)

                    if plot:
                        plt.plot([reflpoint[0], (-v + reflpoint)[0]], [reflpoint[1], (-v + reflpoint)[1]], '-', color='k', lw=0.5, alpha=0.5)
                        plt.plot([reflpoint[0], (vnew + reflpoint)[0]], [reflpoint[1], (vnew + reflpoint)[1]], '-', color='k', lw=1)

                    if self.log:
                        print("    trying", xk)
                    accept = np.logical_and(xk > 0, xk < 1).all() and region.inside(xk.reshape((1, -1)))
                    if accept:
                        pk = transform(xk)
                        Lk = loglike(pk.reshape((1, -1)))[0]
                        nc += 1
                        if Lk >= Lmin:
                            jump_successful = True
                            uret, pret, Lret = xk, pk, Lk
                            if self.log:
                                print("successful reflect!")
                            self.path.add(inew, xk, vk, Lk)
                            self.adjust_accept(True, xk, pk, Lk, nc)
                        else:
                            if self.log:
                                print("unsuccessful reflect")
                            self.adjust_accept(False, xk, pk, Lk, nc)
                    else:
                        if self.log:
                            print("unsuccessful reflect out of region")
                        self.adjust_outside_region()

                    if plot:
                        plt.plot(xk[0], xk[1], 'x', color='g' if jump_successful else 'r', ms=8)

                    if not jump_successful:
                        # unsuccessful. mark as deadend
                        self.deadends.add(inew)
                else:
                    self.adjust_accept(False, uret, pret, Lret, nc)

                assert inew in self.cache or inew in self.deadends, (inew in self.cache, inew in self.deadends)
        else:
            # stuck, proposal did not move us
            self.nstuck += 1
            self.adjust_accept(False, uret, pret, Lret, nc)

        # increase step count
        self.istep += 1
        if self.istep == self.nsteps:
            if self.log:
                print("triggering re-orientation")
            # reset path so we go in a new direction
            self.terminate_path()
            self.istep = 0

        # if had enough resets, return final point
        if self.iresets >= self.nresets:
            if self.log:
                print("walked %d paths; returning sample" % self.iresets)
            self.adjust_scale(maxlength=len(uret)**0.5)
            self.start()
            self.last = None, None
            return uret, pret, Lret, nc

        # do not have a independent sample yet
        return None, None, None, nc
535

536

537
class OtherSamplerProxy(object):
    """Proxy for ClockedSamplers."""

    def __init__(self, nnewdirections, sampler='steps', nsteps=0,
                 balance=0.9, scale=0.1, nudge=1.1, log=False):
        """Initialise sampler.

        Parameters
        -----------
        nnewdirections: int
            number of new directions (trajectories) to walk before the
            sample is considered independent.
        sampler: str
            which sampler to use ('steps', 'bisect' or 'nuts')
        nsteps:
            number of steps in sampler
        balance:
            acceptance rate to target. If the acceptance rate is below it,
            the scale is decreased, otherwise it is increased.
        scale:
            initial proposal scale
        nudge:
            adjustment factor for scale when acceptance rate is too low or high.
            must be >=1.
        log: bool
            whether to print diagnostic messages.

        """
        self.nsteps = nsteps
        self.samplername = sampler
        self.sampler = None
        self.stepper = None

        self.scale = scale
        self.nudge = nudge
        self.balance = balance
        self.log = log

        self.last = None, None
        self.ncalls = 0
        self.nnewdirections = nnewdirections
        # number of directions started since the last returned sample
        self.nrestarts = 0
        self.nreflections = 0
        self.nreverses = 0
        self.nsteps_done = 0

        self.naccepts = 0
        self.nrejects = 0

        self.logstat = []
        self.logstat_labels = ['accepted', 'scale']

    def __str__(self):
        """Get string representation."""
        return 'Proxy[%s](%dx%d steps, AR=%d%%)' % (
            self.samplername, self.nnewdirections, self.nsteps, self.balance * 100)

    @staticmethod
    def _inside_cube_and_region(u, region):
        """Return True if *u* is strictly inside the unit cube and inside *region*."""
        return bool(
            np.all(u > 0) and np.all(u < 1)
            and np.all(region.inside(u.reshape((1, -1)))))

    def accumulate_statistics(self):
        """Accumulate statistics at end of step sequence."""
        self.nreflections += self.sampler.nreflections
        self.nreverses += self.sampler.nreverses
        # index range covered by the trajectory; compare indices only, the
        # remaining tuple entries may be arrays that cannot be ordered
        indices = [i for i, _, _, _ in self.sampler.points]
        self.nsteps_done += max(indices) - min(indices)

        self.naccepts += self.stepper.naccepts
        self.nrejects += self.stepper.nrejects
        if self.log:
            print("%2d direction encountered %2d accepts, %2d rejects" % (
                self.nrestarts, self.stepper.naccepts, self.stepper.nrejects))

    def adjust_scale(self, maxlength):
        """Adjust proposal scale towards the target acceptance rate ``balance``.

        NOTE: *maxlength* is currently not enforced. If no accepts or
        rejects were recorded, nothing is logged and the scale is unchanged.
        """
        log = self.log
        if log:
            print("%2d | %2d %2d %2d | %f" % (self.nrestarts,
                  self.naccepts, self.nrejects, self.nreflections, self.scale))
        ntotal = self.naccepts + self.nrejects
        if ntotal == 0:
            # no information to adapt on (avoids a division by zero)
            return
        self.logstat.append([self.naccepts / ntotal, self.scale])

        if self.naccepts < ntotal * self.balance:
            if log:
                print("adjusting scale %f down" % self.scale)
            self.scale /= self.nudge
        else:
            if log:
                print("adjusting scale %f up" % self.scale)
            self.scale *= self.nudge
        assert self.scale > 1e-10, self.scale

    def startup(self, region, us, Ls):
        """Choose a new random starting point."""
        if self.log:
            print("starting from scratch...")
        mask = region.inside(us)
        assert mask.any(), (
            "None of the live points satisfies the current region!",
            region.maxradiussq, region.u, region.unormed, us)
        i = np.random.randint(mask.sum())
        self.starti = i
        ui = us[mask,:][i]
        assert np.logical_and(ui > 0, ui < 1).all(), ui
        Li = Ls[mask][i]
        self.last = ui, Li
        self.ncalls = 0
        self.nrestarts = 0

        self.nreflections = 0
        self.nreverses = 0
        self.nsteps_done = 0
        self.naccepts = 0
        self.nrejects = 0

        self.sampler = None
        self.stepper = None

    def start_direction(self, region):
        """Choose a new random direction and set up sampler and stepper."""
        if self.log:
            print("choosing random direction")
        ui, Li = self.last
        v = generate_random_direction(ui, region, scale=self.scale)

        self.nrestarts += 1

        samplingpath = SamplingPath(ui, v, Li)
        contourpath = ContourSamplingPath(samplingpath, region)
        if self.samplername == 'steps':
            self.sampler = ClockedStepSampler(contourpath, log=self.log)
            self.stepper = DirectJumper(self.sampler, self.nsteps, log=self.log)
        elif self.samplername == 'bisect':
            self.sampler = ClockedBisectSampler(contourpath, log=self.log)
            self.stepper = DirectJumper(self.sampler, self.nsteps, log=self.log)
        elif self.samplername == 'nuts':
            self.sampler = ClockedNUTSSampler(contourpath, log=self.log)
            self.stepper = IntervalJumper(self.sampler, self.nsteps, log=self.log)
        else:
            assert False

    def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=40, plot=False):
        """Get next point.

        Parameters
        ----------
        region: MLFriends
            region.
        Lmin: float
            loglikelihood threshold
        us: array of vectors
            current live points
        Ls: array of floats
            current live point likelihoods
        transform: function
            transform function
        loglike: function
            loglikelihood function
        ndraw: int
            number of draws to attempt simultaneously.
        plot: bool
            whether to produce debug plots.

        Returns
        -------
        (u, p, L, ncalls) once ``nnewdirections`` directions have been
        walked, otherwise (None, None, None, 0).

        """
        # find most recent point in history conforming to current Lmin
        ui, Li = self.last
        if Li is not None and not Li >= Lmin:
            ui, Li = None, None

        if Li is not None and not region.inside(ui.reshape((1,-1))):
            # region was updated and we are not inside anymore
            # so reset
            ui, Li = None, None

        if Li is None:
            self.startup(region, us, Ls)
        if self.sampler is None:
            self.start_direction(region)

        self.stepper.prepare_jump()
        Llast = None
        gaps = {}
        while True:
            if not self.sampler.is_done():
                u, is_independent = self.sampler.next(Llast=Llast)
                if not is_independent and u is not None:
                    # should evaluate point; points outside the unit cube
                    # or region are treated as rejected without evaluation
                    Llast = None
                    if self._inside_cube_and_region(u, region):
                        p = transform(u.reshape((1, -1)))
                        L = loglike(p)[0]
                        self.ncalls += 1
                        if L > Lmin:
                            Llast = L
            else:
                u, i = self.stepper.check_gaps(gaps)
                if u is None:
                    unew, Lnew = self.stepper.make_jump(gaps)
                    break  # done!
                # check that u is allowed:
                assert i not in gaps
                gaps[i] = True
                if self._inside_cube_and_region(u, region):
                    p = transform(u.reshape((1, -1)))
                    L = loglike(p)[0]
                    self.ncalls += 1
                    if L > Lmin:
                        # point is OK
                        gaps[i] = False
                        unew, Lnew = u, L
                        break

        assert np.isfinite(unew).all(), unew
        assert np.isfinite(Lnew).all(), Lnew

        self.accumulate_statistics()
        # forget sampler
        self.last = unew, Lnew
        self.sampler = None
        self.stepper = None
        # done, reset:
        if self.nrestarts >= self.nnewdirections:
            xnew = transform(unew)
            self.adjust_scale(maxlength=len(unew)**0.5)
            # forget as starting point
            self.last = None, None
            self.nrestarts = 0
            return unew, xnew, Lnew, self.ncalls
        else:
            return None, None, None, 0

    def plot(self, filename):
        """Plot sampler statistics to *filename*."""
        if len(self.logstat) == 0:
            return

        parts = np.transpose(self.logstat)
        plt.figure(figsize=(10, 1 + 3 * len(parts)))
        for i, (label, part) in enumerate(zip(self.logstat_labels, parts)):
            plt.subplot(len(parts), 1, 1 + i)
            plt.ylabel(label)
            plt.plot(part)
            if np.min(part) > 0:
                plt.yscale('log')
        plt.savefig(filename, bbox_inches='tight')
        plt.close()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc