hahnec / torchimize / build 6996238836 (github push by hahnec: "feat(ver): version increment")

26 Nov 2023 02:28PM UTC coverage: 81.932% (+0.7%) from 81.261%

107 of 155 branches covered (69.03%); branch coverage is included in the aggregate %.
351 of 404 relevant lines covered (86.88%)
2.61 hits per line

Source file: /torchimize/functions/parallel/gda_fun_parallel.py (file coverage: 82.81%)
__author__ = "Christopher Hahne"
__email__ = "inbox@christopherhahne.de"
__license__ = """
    Copyright (c) 2022 Christopher Hahne <inbox@christopherhahne.de>
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import torch
from typing import Union, Callable, Tuple, List

from torchimize.functions.jacobian import jacobian_approx_t

def gradient_descent_parallel(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable = None,
        args: Union[Tuple, List] = (),
        wvec: torch.Tensor = None,
        ftol: float = 1e-8,
        ptol: float = 1e-8,
        gtol: float = 1e-8,
        l: float = 1.,
        max_iter: int = 100,
    ) -> List[torch.Tensor]:
    """
    Gradient Descent implementation for parallel least-squares fitting of non-linear functions with conditions.

    :param p: initial value(s)
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param args: optional arguments passed to function
    :param wvec: weights vector used in reduction of multiple costs
    :param ftol: relative change in cost function as stop condition
    :param ptol: relative change in independent variables as stop condition
    :param gtol: maximum gradient tolerance as stop condition
    :param l: step size damping parameter
    :param max_iter: maximum number of iterations
    :return: list of results
    """

    if len(args) > 0:
        # pass optional arguments to function
        fun = lambda p: function(p, *args)
    else:
        fun = function

    if jac_function is None:
        # use numerical Jacobian if analytical is not provided
        jac_fun = lambda p: jacobian_approx_t(p, f=fun)
    else:
        jac_fun = lambda p: jac_function(p, *args)

    assert len(p.shape) == 2, 'parameter tensor is supposed to have 2 dims, but has %s' % str(len(p.shape))

    wvec = torch.ones(1, device=p.device, dtype=p.dtype) if wvec is None else wvec

    p_list = []
    f_prev = torch.zeros(1, device=p.device, dtype=p.dtype)
    while len(p_list) < max_iter:
        pn, f, h = newton_raphson_step(p, fun, jac_fun, wvec, l)
        g = h/l

        # batched stop conditions
        gcon = torch.max(abs(g), dim=-1)[0] < gtol
        pcon = (h**2).sum(-1)**.5 < ptol*(ptol + (p**2).sum(-1)**.5)
        fcon = ((f_prev-f)**2).sum((-1,-2)) < ((ftol*f)**2).sum((-1,-2)) if f_prev.shape == f.shape else torch.zeros_like(gcon)
        f_prev = f.clone()

        # update only parameters which have not converged yet
        converged = gcon | pcon | fcon
        p[~converged] = pn[~converged]
        p_list.append(p.clone())

        if converged.all():
            break

    return p_list

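For orientation, a minimal usage sketch of gradient_descent_parallel: it fits two straight-line models in parallel to synthetic data with an analytic Jacobian. The residual and Jacobian callables follow the layout that newton_raphson_step below appears to assume (batch x cost terms x residuals, and batch x cost terms x residuals x parameters); the names cost_fun, jac_fun, t and y_obs are illustrative and not part of the module, and the import path is taken from the source-file path above.

import torch
from torchimize.functions.parallel.gda_fun_parallel import gradient_descent_parallel

def cost_fun(p, t, y):
    # residuals of a per-problem linear model p[:, 0]*t + p[:, 1],
    # returned with a singleton cost dimension -> shape (batch, 1, n_points)
    model = p[:, 0:1]*t + p[:, 1:2]
    return (model - y)[:, None, :]

def jac_fun(p, t, y):
    # analytic Jacobian of the residuals w.r.t. p -> shape (batch, 1, n_points, 2)
    b, n = p.shape[0], t.shape[-1]
    j = torch.stack([t.expand(b, n), torch.ones(b, n)], dim=-1)
    return j[:, None, :, :]

t = torch.linspace(0, 1, 16)
p_true = torch.tensor([[2., -1.], [.5, 3.]])   # two fitting problems solved in parallel
y_obs = p_true[:, 0:1]*t + p_true[:, 1:2]

p_list = gradient_descent_parallel(
    torch.zeros_like(p_true),                  # initial guesses, shape (batch, n_params)
    function=cost_fun,
    jac_function=jac_fun,
    args=(t, y_obs),
    l=.5,
    max_iter=50,
)
print(p_list[-1])                              # last iterate approaches p_true
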
def gradient_descent_parallel_plain(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable,
        wvec: torch.Tensor,
        l: float = 1.,
        max_iter: int = 100,
    ) -> torch.Tensor:
    """
    Gradient Descent implementation for parallel least-squares fitting of non-linear functions without conditions.

    :param p: initial value(s)
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param wvec: weights vector used in reduction of multiple costs
    :param l: step size damping parameter
    :param max_iter: maximum number of iterations
    :return: result
    """

    for _ in range(max_iter):
        p = newton_raphson_step(p, function=function, jac_function=jac_function, wvec=wvec, l=l)[0]

    return p

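Continuing the sketch above (same hypothetical cost_fun, jac_fun, t and y_obs), the plain variant runs a fixed number of steps without stop conditions and has no args parameter, so extra arguments have to be bound via closures and wvec must be passed explicitly:

from torchimize.functions.parallel.gda_fun_parallel import gradient_descent_parallel_plain

p_plain = gradient_descent_parallel_plain(
    torch.zeros_like(p_true),
    function=lambda p: cost_fun(p, t, y_obs),      # bind the data via a closure
    jac_function=lambda p: jac_fun(p, t, y_obs),
    wvec=torch.ones(1),                            # single cost term, unit weight
    l=.5,
    max_iter=50,
)
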
def newton_raphson_step(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable,
        wvec: torch.Tensor,
        l: float = 1.,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Gradient Descent step function for parallel least-squares fitting of non-linear functions

    :param p: current guess
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param wvec: weights vector used in reduction of multiple costs
    :param l: step size damping parameter
    :return: tuple of results
    """

    fc = function(p)
    jc = jac_function(p)
    f = torch.einsum('bcp,c->bp', fc, wvec)
    j = torch.einsum('bcpi,c->bpi', jc, wvec)
    try:
        h = torch.linalg.lstsq(j.double(), f.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)
    except torch._C._LinAlgError:
        jmin_rank = min(j.shape[1:])
        rank_mask = torch.linalg.matrix_rank(j) < jmin_rank
        j[rank_mask] = torch.eye(*j.shape[1:], dtype=j.dtype, device=j.device)
        h = torch.linalg.lstsq(j.double(), f.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)

    p -= l*h

    return p, f, h
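
The einsum strings document the tensor layout newton_raphson_step works with: fc is batch x cost x residuals ('bcp'), jc is batch x cost x residuals x parameters ('bcpi'), and wvec weights the cost dimension. Each step then solves the batched linear least-squares system J h ≈ f and updates p -= l*h. A standalone shape check with random placeholder tensors (sizes are arbitrary, chosen only for illustration):

import torch

b, c, n, i = 4, 2, 16, 3                       # batch, cost terms, residuals, parameters
fc = torch.randn(b, c, n)
jc = torch.randn(b, c, n, i)
wvec = torch.ones(c)
f = torch.einsum('bcp,c->bp', fc, wvec)        # (b, n) weighted residuals
j = torch.einsum('bcpi,c->bpi', jc, wvec)      # (b, n, i) weighted Jacobians
h = torch.linalg.lstsq(j.double(), f.double())[0]
print(f.shape, j.shape, h.shape)               # h holds one parameter-sized step per problem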