rmcar17 / cogent3 / build 8919598942 (push · github · web-flow)

02 May 2024 04:39AM UTC · coverage: 91.82% (-0.008% from 91.828%)

Merge pull request #1849 from GavinHuttley/develop

DEP: deleted code marked for deprecation in upcoming release, fixes #1840

30152 of 32838 relevant lines covered (91.82%) · 11.01 hits per line

Source File

/src/cogent3/maths/scipy_optimize.py · 84.92% covered

# We don't want to depend on the monolithic, fortranish,
# Num-overlapping, mac-unfriendly SciPy.  But this
# module is too good to pass up. It has been lightly customised for
# use in Cogent.  Changes made to fmin_powell and brent: allow custom
# line search function (to allow bound_brent to be passed in), cope with
# infinity, tol specified as an absolute value, not a proportion of f,
# and more info passed out via callback.

# ******NOTICE***************
# optimize.py module by Travis E. Oliphant
#
# You may copy and use this module as you see fit with no
# guarantee implied provided you keep this notice in all copies.
# *****END NOTICE************


# Minimization routines

__all__ = ["fmin_powell", "brent", "bracket"]

import builtins

import numpy

from numpy import absolute, asarray, atleast_1d, eye, sqrt, squeeze


# These have been copied from Numeric's MLab.py
# I don't think they made the transition to scipy_core


def max(m, axis=0):
    """max(m,axis=0) returns the maximum of m along dimension axis."""
    m = asarray(m)
    return numpy.maximum.reduce(m, axis)


def min(m, axis=0):
    """min(m,axis=0) returns the minimum of m along dimension axis."""
    m = asarray(m)
    return numpy.minimum.reduce(m, axis)


def is_array_scalar(x):
    """Test whether `x` is either a scalar or an array scalar."""
    return len(atleast_1d(x)) == 1


abs = absolute

pymin = builtins.min
pymax = builtins.max


_epsilon = sqrt(numpy.finfo(float).eps)


def wrap_function(function, args):
    ncalls = [0]

    def function_wrapper(x):
        ncalls[0] += 1
        return function(x, *args)

    return ncalls, function_wrapper

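# Usage sketch (illustrative only, not from the original module): wrap_function
# pairs a callable with a mutable call counter so the optimiser can report how
# many objective evaluations were made. The quadratic below is an assumed demo
# objective.
#
#   >>> ncalls, f = wrap_function(lambda x, offset: (x - offset) ** 2, (3.0,))
#   >>> f(1.0)
#   4.0
#   >>> ncalls[0]
#   1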

class Brent:
    # need to rethink design of __init__

    def __init__(self, func, tol=1.48e-8, maxiter=500):
        self.func = func
        self.tol = tol
        self.maxiter = maxiter
        self._mintol = 1.0e-11
        self._cg = 0.3819660
        self.xmin = None
        self.fval = None
        self.iter = 0
        self.funcalls = 0
        self.brack = None
        self._brack_info = None

    # need to rethink design of set_bracket (new options, etc)
    def set_bracket(self, brack=None):
        self.brack = brack
        self._brack_info = self.get_bracket_info()

    def get_bracket_info(self):
        # set up
        func = self.func
        brack = self.brack
        ### BEGIN core bracket_info code ###
        ### carefully DOCUMENT any CHANGES in core ##
        if brack is None:
            xa, xb, xc, fa, fb, fc, funcalls = bracket(func)
        elif len(brack) == 2:
            xa, xb, xc, fa, fb, fc, funcalls = bracket(func, xa=brack[0], xb=brack[1])
        elif len(brack) == 3:
            xa, xb, xc = brack
            if xa > xc:  # swap so xa < xc can be assumed
                dum = xa
                xa = xc
                xc = dum
            assert (xa < xb) and (xb < xc), "Not a bracketing interval."
            fa = func(xa)
            fb = func(xb)
            fc = func(xc)
            assert (fb < fa) and (fb < fc), "Not a bracketing interval."
            funcalls = 3
        else:
            raise ValueError("Bracketing interval must be length 2 or 3 sequence.")
        ### END core bracket_info code ###

        self.funcalls += funcalls
        return xa, xb, xc, fa, fb, fc

    def optimize(self):
        # set up for optimization
        func = self.func
        if self._brack_info is None:
            self.set_bracket(None)
        xa, xb, xc, fa, fb, fc = self._brack_info
        _mintol = self._mintol
        _cg = self._cg
        #################################
        # BEGIN CORE ALGORITHM
        # we are making NO CHANGES in this
        #################################
        x = w = v = xb
        fw = fv = fx = func(x)
        if xa < xc:
            a = xa
            b = xc
        else:
            a = xc
            b = xa
        deltax = 0.0
        funcalls = 1
        iter = 0
        while iter < self.maxiter:
            tol1 = self.tol * abs(x) + _mintol
            tol2 = 2.0 * tol1
            xmid = 0.5 * (a + b)
            if abs(x - xmid) < (tol2 - 0.5 * (b - a)):  # check for convergence
                break
            infinities_present = [f for f in [fw, fv, fx] if numpy.isposinf(f)]
            if infinities_present or (abs(deltax) <= tol1):
                if x >= xmid:
                    deltax = a - x  # do a golden section step
                else:
                    deltax = b - x
                rat = _cg * deltax
            else:  # do a parabolic step
                tmp1 = (x - w) * (fx - fv)
                tmp2 = (x - v) * (fx - fw)
                p = (x - v) * tmp2 - (x - w) * tmp1
                tmp2 = 2.0 * (tmp2 - tmp1)
                if tmp2 > 0.0:
                    p = -p
                tmp2 = abs(tmp2)
                dx_temp = deltax
                deltax = rat
                # check parabolic fit
                if (
                    (p > tmp2 * (a - x))
                    and (p < tmp2 * (b - x))
                    and (abs(p) < abs(0.5 * tmp2 * dx_temp))
                ):
                    rat = p * 1.0 / tmp2  # if parabolic step is useful.
                    u = x + rat
                    if (u - a) < tol2 or (b - u) < tol2:
                        if xmid - x >= 0:
                            rat = tol1
                        else:
                            rat = -tol1
                else:
                    if x >= xmid:
                        deltax = a - x  # if it's not do a golden section step
                    else:
                        deltax = b - x
                    rat = _cg * deltax

            if abs(rat) < tol1:  # update by at least tol1
                if rat >= 0:
                    u = x + tol1
                else:
                    u = x - tol1
            else:
                u = x + rat
            fu = func(u)  # calculate new output value
            funcalls += 1

            if fu > fx:  # if it's bigger than current
                if u < x:
                    a = u
                else:
                    b = u
                if (fu <= fw) or (w == x):
                    v = w
                    w = u
                    fv = fw
                    fw = fu
                elif (fu <= fv) or (v == x) or (v == w):
                    v = u
                    fv = fu
            else:
                if u >= x:
                    a = x
                else:
                    b = x
                v = w
                w = x
                x = u
                fv = fw
                fw = fx
                fx = fu

            iter += 1
        #################################
        # END CORE ALGORITHM
        #################################

        self.xmin = x
        self.fval = fx
        self.iter = iter
        self.funcalls = funcalls

    def get_result(self, full_output=False):
        if full_output:
            return self.xmin, self.fval, self.iter, self.funcalls
        else:
            return self.xmin

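# Usage sketch (illustrative only, not from the original module): the Brent
# class can be driven directly when more control is wanted than the brent()
# convenience wrapper below provides. The quadratic is an assumed demo
# objective.
#
#   >>> opt = Brent(lambda x: (x - 2.0) ** 2)
#   >>> opt.set_bracket((0.0, 1.0))
#   >>> opt.optimize()
#   >>> round(float(opt.xmin), 6)
#   2.0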

def brent(func, brack=None, tol=1.48e-8, full_output=0, maxiter=500):
    """Given a function of one-variable and a possible bracketing interval,
    return the minimum of the function isolated to a fractional precision of
    tol.

    :Parameters:

        func : callable f(x)
            Objective function.
        brack : tuple
            Triple (a,b,c) where (a<b<c) and func(b) <
            func(a),func(c).  If bracket consists of two numbers (a,c)
            then they are assumed to be a starting interval for a
            downhill bracket search (see `bracket`); it doesn't always
            mean that the obtained solution will satisfy a<=x<=c.
        full_output : bool
            If True, return all output args (xmin, fval, iter,
            funcalls).

    :Returns:

        xmin : ndarray
            Optimum point.
        fval : float
            Optimum value.
        iter : int
            Number of iterations.
        funcalls : int
            Number of objective function evaluations made.

    Notes
    -----

    Uses inverse parabolic interpolation when possible to speed up convergence
    of the golden section method.

    """
    brent = Brent(func=func, tol=tol, maxiter=maxiter)
    brent.set_bracket(brack)
    brent.optimize()
    return brent.get_result(full_output=full_output)

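# Usage sketch (illustrative only, not from the original module); the quadratic
# is an assumed demo objective. A two-number brack is grown into a full bracket
# via bracket() before Brent's method runs.
#
#   >>> xmin = brent(lambda x: (x - 2.0) ** 2, brack=(0.0, 1.0))
#   >>> round(xmin, 6)
#   2.0
#   >>> xmin, fval, niter, ncalls = brent(
#   ...     lambda x: (x - 2.0) ** 2, brack=(0.0, 1.0), full_output=1
#   ... )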

def bracket(func, xa=0.0, xb=1.0, args=(), grow_limit=110.0, maxiter=1000):
    """Given a function and distinct initial points, search in the
    downhill direction (as defined by the initial points) and return
    new points xa, xb, xc that bracket the minimum of the function
    f(xa) > f(xb) < f(xc). It doesn't always mean that the obtained
    solution will satisfy xa<=x<=xb.

    :Parameters:

        func : callable f(x,*args)
            Objective function to minimize.
        xa, xb : float
            Bracketing interval.
        args : tuple
            Additional arguments (if present), passed to `func`.
        grow_limit : float
            Maximum grow limit.
        maxiter : int
            Maximum number of iterations to perform.

    :Returns: xa, xb, xc, fa, fb, fc, funcalls

        xa, xb, xc : float
            Bracket.
        fa, fb, fc : float
            Objective function values in bracket.
        funcalls : int
            Number of function evaluations made.

    """
    _gold = 1.618034
    _verysmall_num = 1e-21
    fa = func(*(xa,) + args)
    fb = func(*(xb,) + args)
    if fa < fb:  # Switch so fa > fb
        dum = xa
        xa = xb
        xb = dum
        dum = fa
        fa = fb
        fb = dum
    xc = xb + _gold * (xb - xa)
    fc = func(*((xc,) + args))
    funcalls = 3
    iter = 0
    while fc < fb:
        tmp1 = (xb - xa) * (fb - fc)
        tmp2 = (xb - xc) * (fb - fa)
        val = tmp2 - tmp1
        if abs(val) < _verysmall_num:
            denom = 2.0 * _verysmall_num
        else:
            denom = 2.0 * val
        w = xb - ((xb - xc) * tmp2 - (xb - xa) * tmp1) / denom
        wlim = xb + grow_limit * (xc - xb)
        if iter > maxiter:
            raise RuntimeError("Too many iterations.")
        iter += 1
        if (w - xc) * (xb - w) > 0.0:
            fw = func(*((w,) + args))
            funcalls += 1
            if fw < fc:
                xa = xb
                xb = w
                fa = fb
                fb = fw
                return xa, xb, xc, fa, fb, fc, funcalls
            elif fw > fb:
                xc = w
                fc = fw
                return xa, xb, xc, fa, fb, fc, funcalls
            w = xc + _gold * (xc - xb)
            fw = func(*((w,) + args))
            funcalls += 1
        elif (w - wlim) * (wlim - xc) >= 0.0:
            w = wlim
            fw = func(*((w,) + args))
            funcalls += 1
        elif (w - wlim) * (xc - w) > 0.0:
            fw = func(*((w,) + args))
            funcalls += 1
            if fw < fc:
                xb = xc
                xc = w
                w = xc + _gold * (xc - xb)
                fb = fc
                fc = fw
                fw = func(*((w,) + args))
                funcalls += 1
        else:
            w = xc + _gold * (xc - xb)
            fw = func(*((w,) + args))
            funcalls += 1
        xa = xb
        xb = xc
        xc = w
        fa = fb
        fb = fc
        fc = fw
    return xa, xb, xc, fa, fb, fc, funcalls

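# Usage sketch (illustrative only, not from the original module); the quadratic
# is an assumed demo objective. The returned points satisfy fa > fb < fc.
#
#   >>> xa, xb, xc, fa, fb, fc, ncalls = bracket(lambda x: (x - 2.0) ** 2)
#   >>> fa > fb < fc
#   True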

def _linesearch_powell(linesearch, func, p, xi, tol):
    """Line-search algorithm using the supplied `linesearch` function
    (e.g. `brent`).

    Find the minimum of the function ``func(p + alpha*xi)``.

    """

    def myfunc(alpha):
        return func(p + alpha * xi)

    alpha_min, fret, iter, num = linesearch(myfunc, full_output=1, tol=tol)
    xi = alpha_min * xi
    return squeeze(fret), p + xi, xi

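# Usage sketch (illustrative only, not from the original module): a single
# Powell line search along the first coordinate axis of an assumed 2-D
# quadratic, using brent as the line-search routine.
#
#   >>> def f(v):
#   ...     return (v[0] - 1.0) ** 2 + (v[1] + 2.0) ** 2
#   >>> p0 = numpy.array([0.0, 0.0])
#   >>> direction = numpy.array([1.0, 0.0])
#   >>> fret, p1, step = _linesearch_powell(brent, f, p0, direction, 1e-4)
#   >>> bool(numpy.allclose(p1, [1.0, 0.0], atol=1e-2))
#   True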

def fmin_powell(
    func,
    x0,
    args=(),
    xtol=1e-4,
    ftol=1e-4,
    maxiter=None,
    maxfun=None,
    full_output=0,
    disp=1,
    retall=0,
    callback=None,
    direc=None,
    linesearch=brent,
):
    """Minimize a function using modified Powell's method.

    :Parameters:

      func : callable f(x,*args)
          Objective function to be minimized.
      x0 : ndarray
          Initial guess.
      args : tuple
          Extra arguments passed to func.
      callback : callable
          An optional user-supplied function, called after each
          iteration as ``callback(fcalls, xk, fval, delta)``, where
          ``xk`` is the current parameter vector.
      direc : ndarray
          Initial direction set.

    :Returns: (xopt, {fopt, xi, direc, iter, funcalls, warnflag}, {allvecs})

        xopt : ndarray
            Parameter which minimizes `func`.
        fopt : number
            Value of function at minimum: ``fopt = func(xopt)``.
        direc : ndarray
            Current direction set.
        iter : int
            Number of iterations.
        funcalls : int
            Number of function calls made.
        warnflag : int
            Integer warning flag:
                1 : Maximum number of function evaluations.
                2 : Maximum number of iterations.
        allvecs : list
            List of solutions at each iteration.

    *Other Parameters*:

      xtol : float
          Line-search error tolerance.
      ftol : float
          Absolute error in ``func(xopt)`` acceptable for convergence.
      maxiter : int
          Maximum number of iterations to perform.
      maxfun : int
          Maximum number of function evaluations to make.
      full_output : bool
          If True, fopt, xi, direc, iter, funcalls, and
          warnflag are returned.
      disp : bool
          If True, print convergence messages.
      retall : bool
          If True, return a list of the solution at each iteration.


    :Notes:

        Uses a modification of Powell's method to find the minimum of
        a function of N variables.

    """
    # we need to use a mutable object here that we can update in the
    # wrapper function
    fcalls, func = wrap_function(func, args)
    x = asarray(x0).flatten()
    if retall:
        allvecs = [x]
    N = len(x)
    rank = len(x.shape)
    if not -1 < rank < 2:
        raise ValueError("Initial guess must be a scalar or rank-1 sequence.")
    if maxiter is None:
        maxiter = N * 1000
    if maxfun is None:
        maxfun = N * 1000

    if direc is None:
        direc = eye(N, dtype=float)
    else:
        direc = asarray(direc, dtype=float)

    fval = squeeze(func(x))
    x1 = x.copy()
    iter = 0
    ilist = list(range(N))
    while True:
        fx = fval
        bigind = 0
        delta = 0.0
        for i in ilist:
            direc1 = direc[i]
            fx2 = fval
            fval, x, direc1 = _linesearch_powell(
                linesearch, func, x, direc1, xtol * 100
            )
            if (fx2 - fval) > delta:
                delta = fx2 - fval
                bigind = i
        iter += 1
        if callback is not None:
            callback(fcalls[0], x, fval, delta)
        if retall:
            allvecs.append(x)
        if abs(fx - fval) < ftol:
            break
        if fcalls[0] >= maxfun:
            break
        if iter >= maxiter:
            break

        # Construct the extrapolated point
        direc1 = x - x1
        x2 = 2 * x - x1
        x1 = x.copy()
        fx2 = squeeze(func(x2))

        if fx > fx2:
            t = 2.0 * (fx + fx2 - 2.0 * fval)
            temp = fx - fval - delta
            t *= temp * temp
            temp = fx - fx2
            t -= delta * temp * temp
            if t < 0.0:
                fval, x, direc1 = _linesearch_powell(
                    linesearch, func, x, direc1, xtol * 100
                )
                direc[bigind] = direc[-1]
                direc[-1] = direc1

    warnflag = 0
    if fcalls[0] >= maxfun:
        warnflag = 1
        if disp:
            print(
                "Warning: Maximum number of function evaluations has " "been exceeded."
            )
    elif iter >= maxiter:
        warnflag = 2
        if disp:
            print("Warning: Maximum number of iterations has been exceeded")
    else:
        if disp:
            print("Optimization terminated successfully.")
            print(f"         Current function value: {fval:f}")
            print("         Iterations: %d" % iter)
            print("         Function evaluations: %d" % fcalls[0])

    x = squeeze(x)

    if full_output:
        retlist = x, fval, direc, iter, fcalls[0], warnflag
        if retall:
            retlist += (allvecs,)
    else:
        retlist = x
        if retall:
            retlist = (x, allvecs)

    return retlist
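

# Usage sketch (illustrative only, not from the original module): minimise an
# assumed 2-D quadratic; the callback receives the customised
# (fcalls, xk, fval, delta) arguments described in fmin_powell's docstring.
#
#   >>> def quad(v):
#   ...     return (v[0] - 1.0) ** 2 + (v[1] + 2.0) ** 2
#   >>> trace = []
#   >>> xopt = fmin_powell(
#   ...     quad,
#   ...     x0=numpy.array([0.0, 0.0]),
#   ...     disp=0,
#   ...     callback=lambda fcalls, xk, fval, delta: trace.append(float(fval)),
#   ... )
#   >>> bool(numpy.allclose(xopt, [1.0, -2.0], atol=0.05))
#   True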