
hahnec / torchimize / build 6996238836

26 Nov 2023 02:28PM UTC coverage: 81.932% (+0.7%) from 81.261%
Commit "feat(ver): version increment", pushed by hahnec via GitHub

107 of 155 branches covered (69.03%)
Branch coverage included in aggregate %.
351 of 404 relevant lines covered (86.88%)
2.61 hits per line

Source file: /torchimize/functions/parallel/lma_fun_parallel.py (79.76% covered)

__author__ = "Christopher Hahne"
__email__ = "inbox@christopherhahne.de"
__license__ = """
    Copyright (c) 2022 Christopher Hahne <inbox@christopherhahne.de>
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import torch
from typing import Union, Callable, List, Tuple

from torchimize.functions.jacobian import jacobian_approx_t
from torchimize.functions.parallel.newton_parallel import newton_step_parallel


def lsq_lma_parallel(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable = None,
        args: Union[Tuple, List] = (),
        wvec: torch.Tensor = None,
        ftol: float = 1e-8,
        ptol: float = 1e-8,
        gtol: float = 1e-8,
        tau: float = 1e-3,
        meth: str = 'lev',
        rho1: float = .25,
        rho2: float = .75,
        beta: float = 2,
        gama: float = 3,
        max_iter: int = 100,
    ) -> List[torch.Tensor]:
    """
    Levenberg-Marquardt implementation for parallel least-squares fitting of non-linear functions

    :param p: initial value(s)
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param args: optional arguments passed to function
    :param wvec: weights vector used in reduction of multiple costs
    :param ftol: relative change in cost function as stop condition
    :param ptol: relative change in independent variables as stop condition
    :param gtol: maximum gradient tolerance as stop condition
    :param tau: factor to initialize damping parameter
    :param meth: method, which is 'lev' (default) for Levenberg and otherwise Marquardt
    :param rho1: first gain factor threshold for damping parameter adjustment for Marquardt
    :param rho2: second gain factor threshold for damping parameter adjustment for Marquardt
    :param beta: multiplier for damping parameter adjustment for Marquardt
    :param gama: divisor for damping parameter adjustment for Marquardt
    :param max_iter: maximum number of iterations
    :return: list of parameter estimates, one per iteration
    """

    if len(args) > 0:
        # pass optional arguments to function
        fun = lambda p: function(p, *args)
    else:
        fun = function

    if jac_function is None:
        # use numerical Jacobian if analytical is not provided
        jac_fun = lambda p: jacobian_approx_t(p, f=fun)
    else:
        jac_fun = lambda p: jac_function(p, *args)

    assert len(p.shape) == 2, 'parameter tensor is supposed to have 2 dims, but has %s' % str(len(p.shape))

    wvec = torch.ones(1, device=p.device, dtype=p.dtype) if wvec is None else wvec
    D = torch.eye(p.shape[-1], dtype=p.dtype, device=p.device)[None, ...].repeat(p.shape[0], 1, 1)
    u = tau * torch.max(torch.diagonal(D, dim1=-2, dim2=-1), 1)[0]
    sinf = torch.tensor([-torch.inf, torch.inf], dtype=p.dtype, device=p.device)
    ones = torch.ones(p.shape[0], dtype=p.dtype, device=p.device)
    v = 2*ones

    if meth == 'lev':
        lm_uv_step = lambda rho, u, v: levenberg_uv(rho, u, v, ones=ones)
        lm_dg_step = lambda H, D: D * torch.ones_like(H)
    else:
        lm_uv_step = lambda rho, u, v=None: marquardt_uv(rho, u, v, rho1=rho1, rho2=rho2, beta=beta, gama=gama)
        lm_dg_step = lambda H, D: D * torch.max(torch.maximum(H.diagonal(dim1=2), D.diagonal(dim1=2)), dim=1)[0][..., None, None]

    p_list = []
    f_prev = torch.zeros(1, device=p.device, dtype=p.dtype)
    while len(p_list) < max_iter:

        # levenberg-marquardt step
        pn, f, g, H = newton_step_parallel(p, fun, jac_fun, wvec)
        D = lm_dg_step(H, D)
        Hu = H+u[:, None, None]*D
        h = -torch.linalg.lstsq(Hu.double(), g.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)
        f_h = fun(pn+h)
        rho_nom = torch.einsum('bcp,bci->bc', f, f).sum(1) - torch.einsum('bcp,bci->bc', f_h, f_h).sum(1)
        rho_denom = torch.einsum('bnp,bni->bi', h[..., None], (u[:, None]*h-g)[..., None])[..., 0]
        rho = rho_nom / rho_denom
        rho[rho_denom==0] = sinf[(rho_nom > 0).type(torch.int64)][rho_denom==0]
        u, v = lm_uv_step(rho, u, v)
        pn[rho>0, ...] += h[rho>0, ...]

        # batched stop conditions
        gcon = torch.max(abs(g), dim=-1)[0] < gtol
        pcon = (h**2).sum(-1)**.5 < ptol*(ptol + (p**2).sum(-1)**.5)
        fcon = ((f_prev-f)**2).sum((-1,-2)) < ((ftol*f)**2).sum((-1,-2)) if (rho > 0).sum() > 0 and f_prev.shape == f.shape else torch.zeros_like(gcon)
        f_prev = f.clone()

        # update only those parameter sets which have not converged yet
        converged = gcon | pcon | fcon
        p[~converged] = pn[~converged]
        p_list.append(p.clone())

        if converged.all():
            break

    return p_list


def levenberg_uv(
        rho: torch.Tensor,
        u: torch.Tensor,
        v: torch.Tensor,
        ones: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

    u[rho>0] *= torch.maximum(ones/3, 1-(2*rho-1)**3)[rho>0]
    u[rho<0] *= v[rho<0]
    v[rho>0] = 2*ones[rho>0]
    v[rho<0] *= 2

    return u, v


def marquardt_uv(
        rho: torch.Tensor,
        u: torch.Tensor,
        v: torch.Tensor,
        rho1: float,
        rho2: float,
        beta: float,
        gama: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

    u[rho < rho1] *= beta
    u[rho > rho2] /= gama

    return u, v
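
For orientation, a minimal usage sketch of lsq_lma_parallel on a batch of exponential-decay fits. It is not part of the covered file: the residual layout (batch, cost, residual) and the Jacobian layout (batch, cost, residual, parameter) are assumptions inferred from the einsum reductions above, since the contract of newton_step_parallel is not shown on this page; the toy data, the helper names residuals and jacobian, and all constants are hypothetical.

import torch
from torchimize.functions.parallel.lma_fun_parallel import lsq_lma_parallel

# toy problem: fit y = a*exp(-b*t) + c for a batch of 16 independent measurements
t = torch.linspace(0, 1, 32, dtype=torch.float64)
gt = torch.tensor([1.5, 4.0, 0.2], dtype=torch.float64).repeat(16, 1)
data = gt[:, :1]*torch.exp(-gt[:, 1:2]*t) + gt[:, 2:]

def residuals(p):
    # assumed layout: (batch, num_costs, num_residuals) with a single cost term
    model = p[:, :1]*torch.exp(-p[:, 1:2]*t) + p[:, 2:]
    return (model - data)[:, None, :]

def jacobian(p):
    # assumed layout: (batch, num_costs, num_residuals, num_params)
    e = torch.exp(-p[:, 1:2]*t)
    J = torch.stack((e, -p[:, :1]*t*e, torch.ones_like(e)), dim=-1)
    return J[:, None, :, :]

p0 = torch.ones(16, 3, dtype=torch.float64)   # 2D initial guess: (batch, num_params)
p_list = lsq_lma_parallel(p0, function=residuals, jac_function=jacobian, max_iter=50)
print(len(p_list), p_list[-1][0])             # iterations used and one fitted parameter set

The solver returns one parameter snapshot per iteration and, as the batched stop conditions above show, keeps updating only those batch entries that have not yet met the ftol/ptol/gtol criteria, so all 16 fits run in lock-step on a single device.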