
JohannesBuchner / UltraNest / 9f2dd4f6-0775-47e9-b700-af647027ebfa

22 Apr 2024 12:51PM UTC coverage: 74.53% (+0.3%) from 74.242%

push

circleci

web-flow
Merge pull request #118 from njzifjoiez/fixed-size-vectorised-slice-sampler

vectorised slice sampler of fixed batch size

1329 of 2026 branches covered (65.6%)

Branch coverage included in aggregate %.

79 of 80 new or added lines in 1 file covered. (98.75%)

1 existing line in 1 file now uncovered.

4026 of 5159 relevant lines covered (78.04%)

0.78 hits per line
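The percentages above can be reproduced from the raw counts. This Python sketch assumes the aggregate figure pools covered lines and covered branches into a single ratio. That formula is inferred from the numbers on this page, where it matches the stated 74.53%; it is not taken from Coveralls documentation.

```python
# Counts copied from the coverage summary above.
covered_lines, relevant_lines = 4026, 5159
covered_branches, total_branches = 1329, 2026
new_covered, new_total = 79, 80
previous_aggregate = 74.242


def pct(num, den):
    # Percentage rounded to two decimals.
    return round(100.0 * num / den, 2)


line_pct = pct(covered_lines, relevant_lines)
branch_pct = pct(covered_branches, total_branches)
new_pct = pct(new_covered, new_total)

# Assumption: lines and branches are pooled into one aggregate ratio.
aggregate = 100.0 * (covered_lines + covered_branches) / (relevant_lines + total_branches)
change = round(aggregate - previous_aggregate, 1)

print(line_pct, branch_pct, new_pct, round(aggregate, 2), change)
# 78.04 65.6 98.75 74.53 0.3
```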

Source File

/ultranest/flatnuts.py: 85.14% covered
"""
FLATNUTS is an implementation of the No-U-turn sampler
for nested sampling assuming a flat prior space (hyper-cube u-space).

This is highly experimental. It is similar to NoGUTS and suffers from
the same stability problems.

Directional sampling within regions.

Work in unit cube space, assuming a step size.

1. start from a live point
2. choose a random direction based on the whitened space metric
3. for the forward and backward directions:

  1. find the distance at which we leave the spheres (surely outside)
  2. bisect the step that leads out of the likelihood threshold
  3. can we scatter forward?

     - if we stepped outside the unit cube, use the normal to the parameter(s) we stepped out from
     - if a gradient is available, use it at the first outside point
     - for each sphere that contains the last inside point:

       - resize so that the first outside point is on the surface, and get the tangential vector there
         (this vector is just the difference between the sphere center and the last inside point)
       - compute the reflection of the direction vector with the tangential plane
     - choose a forward reflection at random (if any)

  4. test if the next point is inside again. If yes, continue NUTS
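
The bisection in step 2 and the reflection in step 3 can be sketched for one straight segment with NumPy. This is a minimal illustration, not the module's own code. The disk-shaped region, the helper names first_outside_step and reflect, and the use of the disk's geometric normal instead of an estimated gradient are all hypothetical. The index bisection uses the same midpoint rule as bisect_at() further down. The mirror formula has the same form as the vnew expression in reverse(), assuming angle() there computes a dot product.

```python
import numpy as np


def first_outside_step(is_inside, x0, v, left, right):
    # Bisection over integer step indices along the line x0 + i * v.
    # Invariant: step `left` is inside, step `right` is outside.
    while True:
        mid = (left + right) // 2
        if mid == left or mid == right:
            return right  # right is now the first step outside
        if is_inside(x0 + mid * v):
            left = mid
        else:
            right = mid


def reflect(v, normal):
    # Mirror v on the tangent plane orthogonal to `normal`.
    n = normal / np.linalg.norm(normal)
    return v - 2 * np.dot(n, v) * n


# Hypothetical region: a disk of radius 0.42 around the cube centre.
center = np.array([0.5, 0.5])
radius = 0.42


def is_inside(x):
    return np.linalg.norm(x - center) < radius


x0 = np.array([0.4, 0.5])       # a live point inside the region
v = np.array([0.03, 0.04])      # one step of length 0.05
j = first_outside_step(is_inside, x0, v, left=0, right=32)
xj = x0 + j * v                 # first point outside the contour
vnew = reflect(v, xj - center)  # the disk's outward normal is radial
```

Besides the two endpoints, this bisection evaluates five points (steps 16, 8, 12, 10, 9), whereas stepping one point at a time would need ten. That is the kind of saving ClockedBisectSampler, which does not require each step to be evaluated, is designed for.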

NUTS:
  - repeatedly double the number of steps, extending either the forward or the backward side (chosen at random)
  - build a tree; terminate when the start and end directions no longer point forward
  - choose an end point at random out of the sequence
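
The doubling and U-turn test in the NUTS list can be sketched in a few lines. The sketch is simplified and hypothetical. It checks only the two outer ends of the whole tree (the module's build_tree() may also test sub-trees). It uses a circular track with a fixed angular step as a stand-in for a reflected path. It also caps the depth with an invented max_depth parameter. The stopping rule follows the one in ClockedNUTSSampler.next(): stop once the vector from the left end to the right end has a non-positive projection on either end's direction.

```python
import numpy as np


def uturn(x_left, v_left, x_right, v_right):
    # NUTS termination test: the span from the left end to the right end
    # no longer points forward along the direction at either end.
    span = x_right - x_left
    return np.dot(span, v_left) <= 0 or np.dot(span, v_right) <= 0


def nuts_range(position, velocity, rng, max_depth=10):
    # position(i) and velocity(i) describe the track at integer step i;
    # max_depth is a hypothetical safety cap.
    left = right = 0
    for depth in range(max_depth):
        if rng.integers(2) == 1:
            left -= 2 ** depth    # double towards the backward side
        else:
            right += 2 ** depth   # double towards the forward side
        if uturn(position(left), velocity(left), position(right), velocity(right)):
            break
    return left, right


# Hypothetical track: 0.15 rad per step around the unit circle.
def position(i):
    return np.array([np.cos(0.15 * i), np.sin(0.15 * i)])


def velocity(i):
    return np.array([-np.sin(0.15 * i), np.cos(0.15 * i)])


rng = np.random.default_rng(1)
left, right = nuts_range(position, velocity, rng)
new_index = rng.integers(left, right + 1)   # random point on the track
```

On this track the two ends face each other once the tree spans more than half a turn. Whichever sides are chosen, the span after depth d is 2**(d+1) - 1 steps. The loop therefore stops at depth 4, with 31 steps (4.65 rad).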

If the number of steps on any straight line is <10 steps, make the step size smaller.
If the number of steps on any straight line is >100 steps, make the step size slightly bigger.
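
The step-size rule above can be written as a small function. The factors 0.5 and 1.1 are hypothetical, since the text only says "smaller" and "slightly bigger". The choice to let a too-short segment take priority when both conditions hold is also an assumption.

```python
def adapt_step_size(step_size, segment_lengths, shrink=0.5, grow=1.1):
    # segment_lengths: number of steps on each straight piece of a track.
    # shrink and grow are hypothetical factors.
    if min(segment_lengths) < 10:
        return step_size * shrink   # some straight line is too short
    if max(segment_lengths) > 100:
        return step_size * grow     # some straight line is very long
    return step_size


print(adapt_step_size(0.01, [12, 4, 30]))
print(adapt_step_size(0.01, [150, 200]))
print(adapt_step_size(0.01, [20, 50]))
```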

Parameters:
 - Number of NUTS tracks (has to be user-tuned to ensure sufficiently independent samples; start from 1 and increase until Z no longer changes)
 - Step size (self-adjusting)
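
The tuning advice for the number of NUTS tracks amounts to a convergence loop on the evidence. This is a minimal sketch, assuming a hypothetical run_logz(ntracks) callable that runs the sampler and returns log Z. The doubling schedule, the 0.5 tolerance on log Z and the max_ntracks cap are illustrative choices, not values from this module.

```python
def tune_ntracks(run_logz, tolerance=0.5, max_ntracks=64):
    # run_logz(ntracks) is a hypothetical callable: run nested sampling
    # with ntracks NUTS tracks and return the log-evidence estimate.
    ntracks = 1
    previous = run_logz(ntracks)
    while ntracks < max_ntracks:
        ntracks *= 2                      # assumption: double each time
        current = run_logz(ntracks)
        if abs(current - previous) < tolerance:
            return ntracks, current       # log Z has stopped changing
        previous = current
    return ntracks, previous              # gave up at max_ntracks


# Stand-in for real runs: log Z settles towards -10 as ntracks grows.
def fake_run_logz(ntracks):
    return -10.0 - 2.0 / ntracks


print(tune_ntracks(fake_run_logz))
```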

Benefits of this algorithm:
 - insensitive to step size
 - insensitive to dimensionality (sqrt scaling), better than slice sampling
 - takes advantage of region information, can accelerate low-d problems as well

Drawbacks:
 - inaccurate reflections degrade dimensionality scaling
 - more complex to implement than slice sampling

"""


import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from .samplingpath import angle, extrapolate_ahead


class SingleJumper(object):
    """ Jump one step at a time. If unsuccessful, reverse direction. """
    def __init__(self, stepsampler, nsteps=0):
        self.stepsampler = stepsampler
        self.direction = +1
        assert nsteps > 0
        self.nsteps = nsteps
        self.isteps = 0
        self.currenti = 0
        self.naccepts = 0
        self.nrejects = 0

    def prepare_jump(self):
        target = self.currenti + self.direction
        self.stepsampler.set_nsteps(target)

    def check_gaps(self, gaps):
        # gaps cannot happen, because we make each jump explicitly
        pass
    # then user runs stepsampler until it is done

    def make_jump(self, gaps={}):
        target = self.currenti + self.direction
        pointi = [(j, xj, vj, Lj) for j, xj, vj, Lj in self.stepsampler.points if j == target]
        accept = len(pointi) > 0
        if accept:
            self.currenti = target
            self.naccepts += 1
        else:
            pointi = [(j, xj, vj, Lj) for j, xj, vj, Lj in self.stepsampler.points if j == self.currenti]
            # reverse
            self.direction *= -1
            self.nrejects += 1

        self.isteps += 1
        return pointi[0][1], pointi[0][3]


class DirectJumper(object):
    """ Jump n steps at once. If a step is unsuccessful, take the remaining steps in the other direction. """
    def __init__(self, stepsampler, nsteps, log=False):
        self.stepsampler = stepsampler
        self.direction = +1
        assert nsteps > 0
        self.nsteps = nsteps
        self.isteps = 0
        self.currenti = 0
        self.naccepts = 0
        self.nrejects = 0
        self.log = log

    def prepare_jump(self):
        target = self.currenti + self.nsteps
        self.stepsampler.set_nsteps(target)

    # then user runs stepsampler until it is done
    def check_gaps(self, gaps):
        pointi = {j: (xj, Lj) for j, xj, vj, Lj in self.stepsampler.points}
        ilo, ihi = min(pointi.keys()), max(pointi.keys())
        currenti = self.currenti
        direction = self.direction
        for _ in range(self.nsteps):
            target = currenti + direction
            accept = ilo <= target <= ihi and not gaps.get(target, False)
            if accept:
                if self.log:
                    print("accepted jump %d->%d" % (currenti, target), 'fwd' if direction == 1 else 'rwd')
                currenti = target
            else:
                # reverse
                if self.log:
                    print("rejected jump %d->%d" % (currenti, target), 'fwd' if direction == 1 else 'rwd')
                direction *= -1

        if self.log: print("--> %d" % currenti)
        # double-check that final point is OK:
        # if we already evaluated it, it is OK
        if currenti in pointi:
            return None, None

        if currenti in gaps:
            assert not gaps[currenti], "could not have jumped into a known gap"
            return None, None

        xj, vj, Lj, onpath = self.stepsampler.contourpath.interpolate(currenti)
        if Lj is not None:
            return None, None

        if self.log: print("    checking for gap ...")
        # otherwise ask the caller to verify it and call us again with
        # gaps[i] = True if outside, gaps[i] = False if OK
        return xj, currenti

    def make_jump(self, gaps={}):
        pointi = {j: (xj, Lj) for j, xj, vj, Lj in self.stepsampler.points}
        ilo, ihi = min(pointi.keys()), max(pointi.keys())

        for self.isteps in range(self.nsteps):
            target = self.currenti + self.direction
            accept = ilo <= target <= ihi and not gaps.get(target, False)
            if accept:
                if self.log:
                    print("accepted jump %d->%d" % (self.currenti, target), 'fwd' if self.direction == 1 else 'rwd')
                self.currenti = target
                self.naccepts += 1
            else:
                if self.log:
                    print("rejected jump %d->%d" % (self.currenti, target), 'fwd' if self.direction == 1 else 'rwd')
                # reverse
                self.direction *= -1
                self.nrejects += 1
        self.isteps += 1

        return pointi[self.currenti]


class IntervalJumper(object):
    """ Use interval to choose final point randomly """
    def __init__(self, stepsampler, nsteps):
        self.stepsampler = stepsampler
        self.direction = +1
        assert nsteps >= 0
        self.nsteps = nsteps
        self.isteps = 0
        self.currenti = 0
        self.naccepts = 0
        self.nrejects = 0

    def prepare_jump(self):
        target = self.currenti + self.nsteps
        self.stepsampler.set_nsteps(target)
        self.stepsampler.set_nsteps(-target)

    # then user runs stepsampler until it is done

    def make_jump(self):
        pointi = {j: (xj, Lj) for j, xj, vj, Lj in self.stepsampler.points}
        ilo, ihi = min(pointi.keys()), max(pointi.keys())
        # valid range of the NUTS tree (requires a ClockedNUTSSampler)
        a, b = self.stepsampler.validrange
        nused = b - a
        # total extent of the track; steps outside a..b were not used
        ntotal = ihi - ilo

        # count the number of accepts and rejects
        self.naccepts = nused
        self.nrejects = ntotal - nused

        return None

class ClockedSimpleStepSampler(object):
    """
    Find a new point with a series of small steps
    """
    def __init__(self, contourpath, plot=False, log=False):
        """
        Starts a sampling track along contourpath.

        contourpath: the path being sampled; holds the evaluated points of
            the track and provides extrapolate(), interpolate() and gradient()
        plot: if set to true, make some debug plots
        log: if set to true, print debug messages
        """
        self.contourpath = contourpath
        self.points = self.contourpath.points
        self.nreflections = 0
        self.nreverses = 0
        self.plot = plot
        self.log = log
        self.reset()

    def reset(self):
        self.goals = []

    def reverse(self, reflpoint, v, plot=False):
        """
        Reflect off the surface at reflpoint going in direction v.

        Returns the new direction. If no gradient is available at
        reflpoint, the direction is simply reversed.
        """
        normal = self.contourpath.gradient(reflpoint, plot=plot)
        if normal is None:
            #assert False
            return -v

        vnew = v - 2 * angle(normal, v) * normal
        if self.log: print("    new direction:", vnew)
        assert vnew.shape == v.shape, (vnew.shape, v.shape)
        assert np.isclose(norm(vnew), norm(v)), (vnew, v, norm(vnew), norm(v))
        #isunitlength(vnew)
        if plot:
            plt.plot([reflpoint[0], (-v + reflpoint)[0]], [reflpoint[1], (-v + reflpoint)[1]], '-', color='k', lw=2, alpha=0.5)
            plt.plot([reflpoint[0], (vnew + reflpoint)[0]], [reflpoint[1], (vnew + reflpoint)[1]], '-', color='k', lw=3)
        return vnew

    def set_nsteps(self, i):
        self.goals.insert(0, ('sample-at', i))

    def is_done(self):
        return self.goals == []

    def expand_onestep(self, fwd, transform, loglike, Lmin):
        """ Helper interface, make one step (forward fwd=True or backward fwd=False) """
        if fwd:
            starti, _, _, _ = max(self.points)
            i = starti + 1
        else:
            starti, _, _, _ = min(self.points)
            i = starti - 1
        return self.expand_to_step(i, transform, loglike, Lmin)

    def expand_to_step(self, nsteps, transform, loglike, Lmin):
        """ Helper interface, go to step nsteps """
        self.set_nsteps(nsteps)
        return self.get_independent_sample(transform, loglike, Lmin)

    def get_independent_sample(self, transform, loglike, Lmin):
        """ Helper interface, call next() until an independent sample is returned """
        Llast = None
        while True:
            sample, is_independent = self.next(Llast)
            if sample is None:
                return None, None
            if is_independent:
                unew, Lnew = sample
                return unew, Lnew
            else:
                # evaluate the proposed point; a likelihood below the
                # threshold Lmin is passed back to next() as None (outside)
                unew = sample
                xnew = transform(unew)
                Llast = loglike(xnew)
                if Llast < Lmin:
                    Llast = None


class ClockedStepSampler(ClockedSimpleStepSampler):
    """
    Find a new point with a series of small steps
    """

    def continue_sampling(self, i):
        if ((i > 0 and self.contourpath.samplingpath.fwd_possible)
                or (i < 0 and self.contourpath.samplingpath.rwd_possible)):
            # we are not done:
            self.goals.insert(0, ('expand-to', i))
            self.goals.append(('sample-at', i))
        else:
            # we are not done, but cannot reach the goal.
            # reverse. Find position from where to reverse
            if i > 0:
                starti, _, _, _ = max(self.points)
                reversei = starti + 1
            else:
                starti, _, _, _ = min(self.points)
                reversei = starti - 1
            if self.log: print("reversing at %d..." % starti)
            # how many steps are missing?
            self.nreverses += 1
            deltai = i - starti
            # request one less because one step is spent on
            # the outside try
            #if self.log: print("   %d steps to do at %d -> [from %d, delta=%d] targeting %d." % (
            #    i - starti, starti, reversei, deltai, reversei - deltai))
            # make this many steps in the other direction
            self.goals.append(('sample-at', reversei - deltai))

    def expand_to(self, i):
        if i > 0 and self.contourpath.samplingpath.fwd_possible:
            starti, startx, startv, _ = max(self.points)
            if i > starti:
                if self.log: print("going forward...", i, starti)
                j = starti + 1
                xj, v = self.contourpath.extrapolate(j)
                if j != i: # ultimate goal not reached yet
                    self.goals.insert(0, ('expand-to', i))
                self.goals.insert(0, ('eval-at', j, xj, v, +1))
                return xj, False
            else:
                if self.log: print("already done...", i, starti)
                # we are already done
                pass
        elif i < 0 and self.contourpath.samplingpath.rwd_possible:
            starti, startx, startv, _ = min(self.points)
            if i < starti:
                if self.log: print("going backwards...", i, starti)
                j = starti - 1
                xj, v = self.contourpath.extrapolate(j)
                if j != i: # ultimate goal not reached yet
                    self.goals.insert(0, ('expand-to', i))
                self.goals.insert(0, ('eval-at', j, xj, v, -1))
                return xj, False
            else:
                if self.log: print("already done...", i, starti)
                # we are already done
                pass
        else:
            # we are trying to go somewhere we cannot.
            # skip to other goals
            pass

    def eval_at(self, j, xj, v, sign, Llast):
        if Llast is not None:
            # we can go about our merry way.
            self.contourpath.add(j, xj, v, Llast)
        else:
            # We stepped outside, so now we need to reflect
            self.nreflections += 1
            if self.log: print("reflecting:", xj, v)
            if self.plot: plt.plot(xj[0], xj[1], 'xr')
            vk = self.reverse(xj, v * sign, plot=self.plot) * sign
            if self.log: print("new direction:", vk)
            xk, vk = extrapolate_ahead(sign, xj, vk, contourpath=self.contourpath)
            if self.log: print("reflection point:", xk)
            self.goals.insert(0, ('reflect-at', j, xk, vk, sign))
            return xk, False

    def reflect_at(self, j, xk, vk, sign, Llast):
        self.nreflections += 1
        if Llast is not None:
            # we can go about our merry way.
            self.contourpath.add(j, xk, vk, Llast)
        else:
            # we are stuck and have to give up this direction
            if self.plot: plt.plot(xk[0], xk[1], 's', mfc='None', mec='r', ms=10)
            if sign == 1:
                self.contourpath.samplingpath.fwd_possible = False
            else:
                self.contourpath.samplingpath.rwd_possible = False


    def next(self, Llast=None):
        """
        Run steps forward or backward to step i (can be positive or
        negative, 0 is the starting point).

        Returns (sample, is_independent): a point to evaluate with
        is_independent=False, a finished sample (x, L) with
        is_independent=True, or (None, False) when no goals remain.
        """
        if self.log: print("next() call", Llast)
        while self.goals:
            if self.log: print("goals: ", self.goals)
            goal = self.goals.pop(0)
            if goal[0] == 'sample-at':
                i = goal[1]
                assert Llast is None

                if not self.contourpath.samplingpath.fwd_possible \
                and not self.contourpath.samplingpath.rwd_possible \
                and len(self.points) == 1:
                    # we are stuck and cannot move.
                    # return the starting point as our best effort
                    starti, startx, startv, startL = self.points[0]
                    if self.log: print("stuck! returning start point", starti)
                    return (startx, startL), True

                # find point
                # here we assume all intermediate points have been sampled
                pointi = [(j, xj, vj, Lj) for j, xj, vj, Lj in self.points if j == i]
                if len(pointi) != 0:
                    # return the previously sampled point
                    _, xj, _, Lj = pointi[0]
                    if self.log: print("returning point", i)
                    return (xj, Lj), True

                self.continue_sampling(i)

            elif goal[0] == 'expand-to':
                i = goal[1]
                ret = self.expand_to(i)
                if ret is not None:
                    return ret

            elif goal[0] == 'eval-at':
                _, j, xj, v, sign = goal
                ret = self.eval_at(j, xj, v, sign, Llast)
                Llast = None
                if ret is not None:
                    return ret

            elif goal[0] == 'reflect-at':
                _, j, xk, vk, sign = goal
                self.reflect_at(j, xk, vk, sign, Llast)
                Llast = None

            else:
                assert False, goal

        return None, False

class ClockedBisectSampler(ClockedStepSampler):
    """
    Step sampler that does not require each step to be evaluated
    """

    def continue_sampling(self, i):
        if i > 0:
            starti, _, _, _ = max(self.points)
            #fwd = True
            inside = i < starti
            more_possible = self.contourpath.samplingpath.fwd_possible
        else:
            starti, _, _, _ = min(self.points)
            #fwd = False
            inside = starti < i
            more_possible = self.contourpath.samplingpath.rwd_possible

        if inside:
            # interpolate point on track
            xj, vj, Lj, onpath = self.contourpath.interpolate(i)
            if self.log: print("target is on track, returning interpolation at %d..." % i, xj, Lj)
            return (xj, Lj), True
        elif more_possible:
            # we are not done:
            self.goals.insert(0, ('expand-to', i))
            if self.log: print("not done yet, continue expanding to %d..." % i)
            self.goals.append(('sample-at', i))
        else:
            # we are not done, but cannot reach the goal.
            # reverse. Find position from where to reverse
            if i > 0:
                starti, _, _, _ = max(self.points)
                reversei = starti + 1
            else:
                starti, _, _, _ = min(self.points)
                reversei = starti - 1
            if self.log: print("reversing at %d..." % starti)
            # how many steps are missing?
            self.nreverses += 1
            deltai = i - starti
            # request one less because one step is spent on
            # the outside try
            if self.log: print("   %d steps to do at %d -> [from %d, delta=%d] targeting %d." % (
                i - starti, starti, reversei, deltai, reversei - deltai))
            # make this many steps in the other direction
            self.goals.append(('sample-at', reversei - deltai))

    def expand_to(self, j):
        # check whether step j has already been reached

        if j > 0 and self.contourpath.samplingpath.fwd_possible:
            #print("going forward...", j)
            starti, startx, startv, _ = max(self.points)
            if j > starti:
                xj, v = self.contourpath.extrapolate(j)
                self.goals.insert(0, ('bisect', starti, startx, startv, None, None, None, j, xj, v, +1))
                #self.goals.append(goal)
                return xj, False
            else:
                # we are already done
                if self.log: print("done going to", j, starti)
                pass
        elif j < 0 and self.contourpath.samplingpath.rwd_possible:
            #print("going backward...", j)
            starti, startx, startv, _ = min(self.points)
            if j < starti:
                xj, v = self.contourpath.extrapolate(j)
                self.goals.insert(0, ('bisect', starti, startx, startv, None, None, None, j, xj, v, -1))
                #self.goals.append(goal)
                return xj, False
            else:
                # we are already done
                if self.log: print("done going to", j)
                pass
        else:
            # we are trying to go somewhere we cannot.
            # skip to other goals
            if self.log: print("cannot go there", j)
            pass

    def bisect_at(self, lefti, leftx, leftv, midi, midx, midv, righti, rightx, rightv, sign, Llast):
        # Bisect to find first point outside

        # left is inside (i: index, x: coordinate, v: direction)
        # mid is the middle just evaluated (if not None)
        # right is outside
        if self.log: print("bisecting ...", lefti, midi, righti)

        if midi is None:
            # check if right is actually outside
            if Llast is None:
                # yes it is. continue below
                pass
            else:
                # right is actually inside,
                # so we jumped all the way in one go
                if self.log: print("successfully went all the way in one jump!")
                self.contourpath.add(righti, rightx, rightv, Llast)
                Llast = None
                return
        else:
            # shrink interval based on previous evaluation point
            if Llast is not None:
                #print("   inside.  updating interval %d-%d" % (midi, righti))
                lefti, leftx, leftv = midi, midx, midv
                self.contourpath.add(midi, midx, midv, Llast)
                Llast = None
            else:
                #print("   outside. updating interval %d-%d" % (lefti, midi))
                righti, rightx, rightv = midi, midx, midv

        # we need to bisect. righti was outside
        midi = (righti + lefti) // 2
        if midi == lefti or midi == righti:
            # we are done bisecting. right is the first point outside
            if self.log: print("  bisecting gave reflection point", righti, rightx, rightv)
            if self.plot: plt.plot(rightx[0], rightx[1], 'xr')
            # compute reflected direction
            vk = self.reverse(rightx, rightv * sign, plot=self.plot) * sign
            if self.log: print("  reversing there", rightv)
            # go from reflection point one step in that direction
            # that is our new point
            xk, vk = extrapolate_ahead(sign, rightx, vk, contourpath=self.contourpath)
            if self.log: print("  making one step from", rightx, rightv, '-->', xk, vk)
            self.nreflections += 1
            if self.log: print("  trying new point,", xk)
            self.goals.insert(0, ('reflect-at', righti, xk, vk, sign))
            return xk, False
        else:
            if self.log: print("  continue bisect at", midi)
            # we should evaluate the middle point
            midx, midv = extrapolate_ahead(midi - lefti, leftx, leftv, contourpath=self.contourpath)
            # continue bisecting
            self.goals.insert(0, ('bisect', lefti, leftx, leftv, midi, midx, midv, righti, rightx, rightv, sign))
            return midx, False


    def next(self, Llast=None):
        """
        Run steps forward or backward to step i (can be positive or
        negative, 0 is the starting point)
        """
        if self.log: print()
        if self.log: print("next() call", Llast)
        while self.goals:
            if self.log: print("goals: ", self.goals)
            goal = self.goals.pop(0)

            if goal[0] == 'sample-at':
                i = goal[1]
                assert Llast is None

                if not self.contourpath.samplingpath.fwd_possible and not self.contourpath.samplingpath.rwd_possible \
                    and len(self.points) == 1:
                    # we are stuck and cannot move.
                    # return the starting point as our best effort
                    if self.log: print("stuck! returning start point.")
                    starti, startx, startv, startL = self.points[0]
                    return (startx, startL), True

                # check if point already sampled
                pointi = [(j, xj, vj, Lj) for j, xj, vj, Lj in self.points if j == i]

                if len(pointi) == 1:
                    # return the previously sampled point
                    _, xj, _, Lj = pointi[0]
                    return (xj, Lj), True

                self.continue_sampling(i)

            elif goal[0] == 'expand-to':
                ret = self.expand_to(goal[1])
                if ret is not None:
                    return ret

            elif goal[0] == 'bisect':
                _, lefti, leftx, leftv, midi, midx, midv, righti, rightx, rightv, sign = goal
                ret = self.bisect_at(lefti, leftx, leftv, midi, midx, midv, righti, rightx, rightv, sign, Llast)
                Llast = None
                if ret is not None:
                    return ret

            elif goal[0] == 'reflect-at':
                _, j, xk, vk, sign = goal
                self.reflect_at(j, xk, vk, sign, Llast)
                Llast = None
            else:
                assert False, goal

        return None, False

class ClockedNUTSSampler(ClockedBisectSampler):
    """
    No-U-turn sampler (NUTS) on flat surfaces.
    """

    def reset(self):
        self.goals = []
        self.left_state = self.points[0][:3]
        self.right_state = self.points[0][:3]
        self.left_warmed_up = False
        self.right_warmed_up = False
        self.tree_built = False
        self.validrange = (0, 0)
        self.tree_depth = 0
        self.current_direction = np.random.randint(2) == 1

    def next(self, Llast=None):
        """
        Repeatedly doubles the number of steps, extending the tree in a
        randomly chosen forward or backward direction (which may include
        reflections, see ClockedStepSampler and ClockedBisectSampler).
        When the track turns back (start and end of the tree point toward
        each other), terminates and returns a random point on that track.
        """
        while not self.tree_built:
            if self.log: print("continue building tree")
            rwd = self.current_direction

            if self.log or self.tree_depth > 7:
                print("NUTS step: tree depth %d, %s" % (self.tree_depth, "rwd" if rwd else "fwd"))


            # make sure the path is prepared for the desired tree
            if rwd:
                goal = ('expand-to', self.left_state[0] - 2**self.tree_depth)
            else:
                goal = ('expand-to', self.right_state[0] + 2**self.tree_depth)

            if goal not in self.goals:
                self.goals.append(goal)

            # work down any open tasks
            while self.goals:
                sample, is_independent = ClockedBisectSampler.next(self, Llast=Llast)
                Llast = None
                if sample is not None:
                    return sample, is_independent

            # now check if terminating
            if rwd:
                self.left_state, _, newrange, newstop = self.build_tree(self.left_state, self.tree_depth, rwd=rwd)
            else:
                _, self.right_state, newrange, newstop = self.build_tree(self.right_state, self.tree_depth, rwd=rwd)

            if not newstop:
                self.validrange = (min(self.validrange[0], newrange[0]), max(self.validrange[1], newrange[1]))
                if self.log: print("  new NUTS range: %d..%d" % (self.validrange[0], self.validrange[1]))

            ileft, xleft, vleft = self.left_state
            iright, xright, vright = self.right_state
            if self.plot: plt.plot([xleft[0], xright[0]], [xleft[1] + (self.tree_depth+1)*0.02, xright[1] + (self.tree_depth+1)*0.02], '--')
            #if j > 5:
            #   print("  first-to-last arrow", ileft, iright, xleft, xright, xright-xleft, " velocities:", vright, vleft)
            #   print("  stopping criteria: ", newstop, angle(xright-xleft, vleft), angle(xright-xleft, vright))

            # avoid U-turns:
            stop = newstop or angle(xright - xleft, vleft) <= 0 or angle(xright - xleft, vright) <= 0

            # stop when we cannot continue in any direction
            stop = stop and (self.contourpath.samplingpath.fwd_possible or self.contourpath.samplingpath.rwd_possible)

            if stop:
                self.tree_built = True
            else:
                self.tree_depth = self.tree_depth + 1
                self.current_direction = np.random.randint(2) == 1

        # Tree was built, we only need to sample from it
        if self.log: print("sampling between", self.validrange)
        return self.sample_chain_point(self.validrange[0], self.validrange[1])

    def sample_chain_point(self, a, b):
        """
        Get a random point on the track between step indices a and b (inclusive).

        Indices are drawn uniformly from a..b; draws that land off the
        contour path are rejected and redrawn.

        Parameters
        ----------
        a: int
            index of the first step
        b: int
            index of the last step

        Returns
        -------
        newpoint: tuple
            tuple of point coordinates and log-likelihood
        is_independent: bool
            always True
        """
        if self.plot:
            for i in range(a, b+1):
                xi, vi, Li, onpath = self.contourpath.interpolate(i)
                plt.plot(xi[0], xi[1], '_ ', color='b', ms=10, mew=2)

        while True:
            i = np.random.randint(a, b+1)
            xi, vi, Li, onpath = self.contourpath.interpolate(i)
            if not onpath:
                continue
            return (xi, Li), True

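    def build_tree(self, startstate, j, rwd):
        """
        Build a sub-tree of depth j (2**j steps) in direction rwd.

        startstate: (i, x, v) state information of the first node
        j: int, height of the tree
        rwd: bool, whether we go backward

        Returns the leftmost and rightmost (i, x, v) states, the index range
        (ileft, iright) they span, and whether a U-turn (stop) was detected.
        """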
        if j == 0:
            # base case: take one step in the chosen direction (backward if rwd)
            i = startstate[0] + (-1 if rwd else +1)
            #self.expand_to_step(i)
            #print("  build_tree@%d" % i, rwd, self.contourpath.samplingpath.fwd_possible, self.contourpath.samplingpath.rwd_possible)
            xi, vi, _, _ = self.contourpath.interpolate(i)
            if self.plot: plt.plot(xi[0], xi[1], 'x', color='gray')
            # this is a good state, so return it
            return (i, xi, vi), (i, xi, vi), (i,i), False

        # recursion-build the left and right subtrees
        (ileft, xleft, vleft), (iright, xright, vright), rangea, stopa = self.build_tree(startstate, j-1, rwd)
        if stopa:
            #print("  one subtree already terminated; returning")
            #plt.plot([xright[0], xleft[0]], [xright[1], xleft[1]], ':', color='navy')
            return (ileft, xleft, vleft), (iright, xright, vright), (ileft,iright), stopa
        if rwd:
            # go back
            (ileft, xleft, vleft), _, rangeb, stopb = self.build_tree((ileft, xleft, vleft), j-1, rwd)
        else:
            _, (iright, xright, vright), rangeb, stopb = self.build_tree((iright, xright, vright), j-1, rwd)
        #print("  subtree termination at %d" % j, stopa, stopb, angle(xright-xleft, vleft), angle(xright-xleft, vright), angle(vleft, vright))
        #plt.plot([xright[0], xleft[0]], [xright[1], xleft[1]], ':', color='gray')
        # NUTS criterion: the start-to-end vector must point in the same direction as the velocity at both end points
        # additional criterion: start and end velocities must not point in opposite directions
        stop = stopa or stopb or angle(xright-xleft, vleft) <= 0 or angle(xright-xleft, vright) <= 0 or angle(vleft, vright) <= 0
        return (ileft, xleft, vleft), (iright, xright, vright), (ileft,iright), stop
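The doubling loop in `next` and the recursion in `build_tree` can be tried out in isolation. The following is a minimal sketch, not UltraNest's implementation. It assumes a precomputed track given by a hypothetical `interpolate(i)` that returns position and velocity at integer step `i` (here a unit circle, via the hypothetical `make_circle_track`). It also drops the goal/bisection machinery of `ClockedBisectSampler` and the `fwd_possible`/`rwd_possible` check, draws the direction at random every round, and uses a plain dot product for `angle`, since only the sign of the result enters the stop conditions.

```python
import numpy as np


def angle(a, b):
    # Only the sign is used by the stop tests: <= 0 means "not pointing the same way".
    return float(np.dot(a, b))


def make_circle_track(dtheta):
    """Hypothetical track: step i sits at angle i*dtheta on the unit circle."""
    def interpolate(i):
        t = i * dtheta
        x = np.array([np.cos(t), np.sin(t)])
        v = np.array([-np.sin(t), np.cos(t)])  # unit tangent (velocity)
        return x, v
    return interpolate


def build_tree(interpolate, startstate, j, rwd):
    """Return (left, right, (ileft, iright), stop) for a sub-tree of 2**j steps."""
    if j == 0:
        # base case: one step in the chosen direction
        i = startstate[0] + (-1 if rwd else 1)
        x, v = interpolate(i)
        return (i, x, v), (i, x, v), (i, i), False
    left, right, _, stop = build_tree(interpolate, startstate, j - 1, rwd)
    if stop:
        return left, right, (left[0], right[0]), True
    # extend the outer end of the first half by a second half of equal size
    if rwd:
        left, _, _, stop = build_tree(interpolate, left, j - 1, rwd)
    else:
        _, right, _, stop = build_tree(interpolate, right, j - 1, rwd)
    (_, xl, vl), (_, xr, vr) = left, right
    d = xr - xl
    stop = stop or angle(d, vl) <= 0 or angle(d, vr) <= 0 or angle(vl, vr) <= 0
    return left, right, (left[0], right[0]), stop


def nuts_range(interpolate, rng, max_depth=20):
    """Double the track in random directions until a U-turn; return the valid index range."""
    x0, v0 = interpolate(0)
    left = right = (0, x0, v0)
    validrange = (0, 0)
    for depth in range(max_depth):
        rwd = bool(rng.integers(2))
        if rwd:
            left, _, newrange, newstop = build_tree(interpolate, left, depth, True)
        else:
            _, right, newrange, newstop = build_tree(interpolate, right, depth, False)
        if not newstop:
            # only sub-trees without an internal U-turn extend the valid range
            validrange = (min(validrange[0], newrange[0]), max(validrange[1], newrange[1]))
        (_, xl, vl), (_, xr, vr) = left, right
        d = xr - xl
        if newstop or angle(d, vl) <= 0 or angle(d, vr) <= 0:
            break
    return validrange


rng = np.random.default_rng(0)
lo, hi = nuts_range(make_circle_track(0.1), rng)
i = int(rng.integers(lo, hi + 1))  # uniform pick from the range, as in sample_chain_point
print(lo, hi, i)
```

On the unit circle with `dtheta=0.1`, the velocity criterion `angle(vl, vr) <= 0` fires inside the depth-5 sub-tree (its end points are 3.1 rad apart), so that sub-tree is discarded and the valid range always spans 31 steps. Which end it grows from depends on the random directions.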
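`sample_chain_point` picks an index uniformly from `a..b` and redraws whenever the point is reported as off the path. Conditional on acceptance, this is a uniform draw over the on-path indices. The following is a minimal standalone sketch of that rejection step. The hypothetical `onpath(i)` predicate stands in for the last value returned by `contourpath.interpolate(i)`. A `max_tries` cap, which the original `while True` loop does not have, guards against a range that contains no on-path point.

```python
import numpy as np


def sample_chain_index(a, b, onpath, rng, max_tries=10_000):
    """Draw an index uniformly from {i in a..b : onpath(i)} by rejection."""
    for _ in range(max_tries):
        i = int(rng.integers(a, b + 1))  # uniform over a..b, both ends included
        if onpath(i):
            return i
    raise RuntimeError("no on-path point found in %d..%d" % (a, b))


rng = np.random.default_rng(42)
# hypothetical path: only even step indices lie on the contour
draws = [sample_chain_index(-5, 5, lambda i: i % 2 == 0, rng) for _ in range(2000)]
print(sorted(set(draws)))  # → [-4, -2, 0, 2, 4]
```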

© 2025 Coveralls, Inc