
hahnec / torchimize / build 5028371039 (pending completion)

Event: push
CI service: github-actions
Committed by: hahnec
Commit: fix(fun): consider new stateless function for torch v2.0.1

Coverage summary:
106 of 157 branches covered (67.52%); branch coverage is included in the aggregate percentage.
6 of 6 new or added lines in 1 file covered (100.0%).
340 of 398 relevant lines covered (85.43%).
2.56 hits per line.

Source File
/torchimize/functions/parallel/gda_fun_parallel.py (71.43% covered)
Lines not covered and partially covered branches are marked with trailing comments in the listing below.
__author__ = "Christopher Hahne"
__email__ = "inbox@christopherhahne.de"
__license__ = """
    Copyright (c) 2022 Christopher Hahne <inbox@christopherhahne.de>
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import torch
from typing import Union, Callable, Tuple, List

from torchimize.functions.jacobian import jacobian_approx_t


def gradient_descent_parallel(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable = None,
        args: Union[Tuple, List] = (),
        wvec: torch.Tensor = None,
        ftol: float = 1e-8,
        ptol: float = 1e-8,
        gtol: float = 1e-8,
        l: float = 1.,
        max_iter: int = 100,
    ) -> List[torch.Tensor]:
    """
    Gradient Descent implementation for parallel least-squares fitting of non-linear functions with stop conditions.

    :param p: initial value(s)
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param args: optional arguments passed to function
    :param wvec: weights vector used in reduction of multiple costs
    :param ftol: relative change in cost function as stop condition
    :param ptol: relative change in independent variables as stop condition
    :param gtol: maximum gradient tolerance as stop condition
    :param l: step size damping parameter
    :param max_iter: maximum number of iterations
    :return: list of results
    """

    if len(args) > 0:  # branch partially covered
        # pass optional arguments to function
        fun = lambda p: function(p, *args)
    else:
        fun = function  # not covered

    if jac_function is None:  # branch partially covered
        # use numerical Jacobian if analytical is not provided
        jac_fun = lambda p: jacobian_approx_t(p, f=fun)  # not covered
    else:
        jac_fun = lambda p: jac_function(p, *args)

    assert len(p.shape) == 2, 'parameter tensor is supposed to have 2 dims, but has %s' % str(len(p.shape))

    wvec = torch.ones(1, device=p.device, dtype=p.dtype) if wvec is None else wvec

    p_list = []
    f_prev = torch.zeros(1, device=p.device, dtype=p.dtype)
    while len(p_list) < max_iter:
        p, f, h = newton_raphson_step(p, fun, jac_fun, wvec, l)
        p_list.append(p.clone())
        g = h/l

        # stop conditions
        gcon = torch.max(abs(g)) < gtol
        pcon = (h**2).sum()**.5 < ptol*(ptol + (p**2).sum()**.5)
        fcon = ((f_prev-f)**2).sum() < ((ftol*f)**2).sum() if f_prev.shape == f.shape else False
        f_prev = f.clone()

        if gcon or pcon or fcon:  # branch partially covered
            break  # not covered

    return p_list
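# ---------------------------------------------------------------------------
# Illustrative usage sketch for gradient_descent_parallel; this helper is not
# part of the original gda_fun_parallel.py. It fits a batch of exponential
# decays y = a * exp(-b * t); the shapes follow the einsum contractions in
# newton_raphson_step further below: p is (batch, n_params), the residual
# function returns (batch, n_costs, n_points) and the Jacobian function
# returns (batch, n_costs, n_points, n_params). All names are hypothetical.
def _example_gradient_descent_parallel():

    t = torch.linspace(0, 1, 50)
    p_true = torch.tensor([[1.5, 2.0], [0.8, 3.0]])
    y = p_true[:, :1] * torch.exp(-p_true[:, 1:] * t[None, :])     # (2, 50) noise-free observations

    def residuals(p):
        model = p[:, :1] * torch.exp(-p[:, 1:] * t[None, :])
        return (model - y)[:, None, :]                             # (batch, 1, n_points)

    def jac(p):
        e = torch.exp(-p[:, 1:] * t[None, :])
        return torch.stack((e, -p[:, :1] * t[None, :] * e), dim=-1)[:, None]

    p_list = gradient_descent_parallel(torch.ones(2, 2), residuals, jac_function=jac)
    return p_list[-1]                                              # final iterate per batch problem
# ---------------------------------------------------------------------------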


def gradient_descent_parallel_plain(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable,
        wvec: torch.Tensor,
        l: float = 1.,
        max_iter: int = 100,
    ) -> torch.Tensor:
    """
    Gradient Descent implementation for parallel least-squares fitting of non-linear functions without stop conditions.

    :param p: initial value(s)
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param wvec: weights vector used in reduction of multiple costs
    :param l: step size damping parameter
    :param max_iter: maximum number of iterations
    :return: result
    """

    for _ in range(max_iter):  # function body not covered in this build
        p = newton_raphson_step(p, function=function, jac_function=jac_function, wvec=wvec, l=l)[0]

    return p


def newton_raphson_step(
        p: torch.Tensor,
        function: Callable,
        jac_function: Callable,
        wvec: torch.Tensor,
        l: float = 1.,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Gradient Descent step function for parallel least-squares fitting of non-linear functions.

    :param p: current guess
    :param function: user-provided function which takes p (and additional arguments) as input
    :param jac_function: user-provided Jacobian function which takes p (and additional arguments) as input
    :param wvec: weights vector used in reduction of multiple costs
    :param l: step size damping parameter
    :return: tuple of updated parameters, weighted residuals and step
    """

    fc = function(p)
    jc = jac_function(p)
    f = torch.einsum('bcp,c->bp', fc, wvec)      # weighted sum of residuals over the cost dimension
    j = torch.einsum('bcpi,c->bpi', jc, wvec)    # weighted sum of Jacobians over the cost dimension
    try:
        h = torch.linalg.lstsq(j.double(), f.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)
    except torch._C._LinAlgError:                # handler not covered in this build
        # replace rank-deficient Jacobians by identity before retrying the least-squares solve
        jmin_rank = min(j.shape[1:])
        rank_mask = torch.linalg.matrix_rank(j) < jmin_rank
        j[rank_mask] = torch.eye(*j.shape[1:], dtype=j.dtype, device=j.device)
        h = torch.linalg.lstsq(j.double(), f.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)

    p -= l*h

    return p, f, h
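The two einsum contractions in newton_raphson_step reduce the stacked per-cost residuals and Jacobians to a single weighted system per batch entry, which torch.linalg.lstsq then solves for the step h. A minimal, self-contained sketch of that reduction (the shapes and weight values below are illustrative only, not taken from the torchimize test suite):

import torch

# illustrative shapes: 4 parallel problems, 2 cost terms, 50 residual points, 3 parameters
fc = torch.randn(4, 2, 50)         # stacked per-cost residuals, as returned by function(p)
jc = torch.randn(4, 2, 50, 3)      # stacked per-cost Jacobians, as returned by jac_function(p)
wvec = torch.tensor([1.0, 0.5])    # one weight per cost term

# 'bcp,c->bp' is a weighted sum over the cost dimension c ...
f = torch.einsum('bcp,c->bp', fc, wvec)
assert torch.allclose(f, (fc * wvec[None, :, None]).sum(dim=1))

# ... and 'bcpi,c->bpi' applies the same weighted sum to the Jacobian stack
j = torch.einsum('bcpi,c->bpi', jc, wvec)
assert torch.allclose(j, (jc * wvec[None, :, None, None]).sum(dim=1))

The least-squares solution h of j·h ≈ f is then applied as p -= l*h, so l acts as a plain damping factor on the step length.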