• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18493491435

14 Oct 2025 10:25AM UTC coverage: 94.424%. First build
18493491435

Pull #643

github

web-flow
Merge 090fb00ed into 6f6ee4c58
Pull Request #643: gh-408: porting straightforward functions in `fields`

200 of 202 branches covered (99.01%)

Branch coverage included in aggregate %.

187 of 212 new or added lines in 5 files covered. (88.21%)

1375 of 1466 relevant lines covered (93.79%)

7.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.84
/glass/algorithm.py
1
"""Module for algorithms."""
2

3
from __future__ import annotations
8✔
4

5
from typing import TYPE_CHECKING
8✔
6

7
import glass._array_api_utils as _utils
8✔
8

9
if TYPE_CHECKING:
10
    import numpy as np
11
    from jaxtyping import Array
12
    from numpy.typing import NDArray
13

14
    from glass._array_api_utils import GlassFloatArray
15

16

17
def nnls(
    a: GlassFloatArray,
    b: GlassFloatArray,
    *,
    tol: float = 0.0,
    maxiter: int | None = None,
) -> GlassFloatArray:
    """
    Compute a non-negative least squares solution.

    Implementation of the algorithm due to [Lawson95]_ as described by
    [Bro97]_.  Finds ``x`` minimising ``||a @ x - b||`` subject to
    ``x >= 0`` with the active-set method.

    Parameters
    ----------
    a
        The matrix.
    b
        The vector.
    tol
        The tolerance for convergence.  Iteration stops once no variable
        outside the passive set has a residual correlation ``w`` larger
        than *tol*.
    maxiter
        The maximum number of iterations.  Defaults to ``3 * n``, where
        ``n`` is the number of columns of *a*.

    Returns
    -------
        The non-negative least squares solution.

    Raises
    ------
    ValueError
        If ``a`` is not a matrix.
    ValueError
        If ``b`` is not a vector.
    ValueError
        If the shapes of ``a`` and ``b`` do not match.

    """
    xp = _utils.get_namespace(a, b)

    a = xp.asarray(a)
    b = xp.asarray(b)

    if a.ndim != 2:
        msg = "input `a` is not a matrix"
        raise ValueError(msg)
    if b.ndim != 1:
        msg = "input `b` is not a vector"
        raise ValueError(msg)
    if a.shape[0] != b.shape[0]:
        msg = "the shapes of `a` and `b` do not match"
        raise ValueError(msg)

    _, n = a.shape

    if maxiter is None:
        maxiter = 3 * n

    index = xp.arange(n)
    # q marks the passive set: variables currently allowed to be positive
    q = xp.full(n, fill_value=False)
    x = xp.zeros(n)
    for _ in range(maxiter):
        if xp.all(q):
            break
        # w = a.T @ (b - a @ x): correlation of each column with the residual
        w = xp.sum((b - a @ x)[..., None] * a, axis=-2)

        # move the inactive variable with the largest w into the passive set
        m = int(index[~q][xp.argmax(w[~q])])
        if w[m] <= tol:
            break
        q[m] = True
        while True:
            # unconstrained least squares restricted to the passive columns
            passive = xp.nonzero(q)[0]
            aq = xp.take(a, passive, axis=1)
            xq = x[q]
            sq = xp.linalg.solve(aq.T @ aq, b @ aq)
            t = sq <= 0
            if not xp.any(t):
                break
            # step from x towards s only as far as feasibility allows: the
            # step length min(x_i / (x_i - s_i)) over s_i <= 0 lies in (0, 1]
            # and drives the first blocking variable to zero
            ratio = xq[t] / (xq[t] - sq[t])
            j = int(xp.argmin(ratio))
            alpha = ratio[j]
            x[q] += alpha * (sq - xq)
            # pin the blocking variable to exactly zero so that rounding
            # error cannot leave it (tiny but positive) in the passive set
            x[int(passive[t][j])] = 0.0
            q[x <= 0] = False
        x[q] = sq
        x[~q] = 0
    return x
101

102

103
def cov_clip(
8✔
104
    cov: NDArray[np.float64] | Array,
105
    rtol: float | None = None,
106
) -> NDArray[np.float64] | Array:
107
    """
108
    Covariance matrix from clipping non-positive eigenvalues.
109

110
    The relative tolerance *rtol* is defined as for
111
    :func:`~array_api.linalg.matrix_rank`.
112

113
    Parameter
114
    ---------
115
    cov
116
        A symmetric matrix (or a stack of matrices).
117
    rtol
118
        An optional relative tolerance for eigenvalues to be considered
119
        positive.
120

121
    Returns
122
    -------
123
        Covariance matrix with negative eigenvalues clipped.
124

125
    """
126
    xp = cov.__array_namespace__()
8✔
127

128
    # Hermitian eigendecomposition
129
    w, v = xp.linalg.eigh(cov)
8✔
130

131
    # get tolerance if not given
132
    if rtol is None:
8✔
133
        rtol = max(v.shape[-2], v.shape[-1]) * xp.finfo(w.dtype).eps
8✔
134

135
    # clip negative diagonal values
136
    w = xp.clip(w, rtol * xp.max(w, axis=-1, keepdims=True), None)
8✔
137

138
    # put matrix back together
139
    # enforce symmetry
140
    v = xp.sqrt(w[..., None, :]) * v
8✔
141
    return xp.matmul(v, xp.matrix_transpose(v))
8✔
142

143

144
def nearcorr(
    a: NDArray[np.float64] | Array,
    *,
    tol: float | None = None,
    niter: int = 100,
) -> NDArray[np.float64] | Array:
    """
    Compute the nearest correlation matrix.

    Returns the nearest correlation matrix using the alternating
    projections algorithm of [Higham02]_.

    Parameters
    ----------
    a
        Square matrix (or a stack of square matrices).
    tol
        Tolerance for convergence. Default is dimension times machine
        epsilon.
    niter
        Maximum number of iterations.

    Returns
    -------
        Nearest correlation matrix.

    Raises
    ------
    ValueError
        If the last two dimensions of *a* differ.

    """
    xp = a.__array_namespace__()

    # shorthand for Frobenius norm
    frob = xp.linalg.matrix_norm

    # get size of the covariance matrix, remembering leading dimensions
    *dim, m, n = a.shape
    if m != n:
        msg = "non-square matrix"
        raise ValueError(msg)

    # default tolerance
    if tol is None:
        tol = n * xp.finfo(a.dtype).eps

    # current result, flatten leading dimensions
    y = xp.reshape(a, (-1, n, n))

    # initial correction is zero; it must have the flattened shape of y
    # (not the shape of a), otherwise y - ds fails to broadcast when a has
    # more than one leading dimension
    ds = xp.zeros_like(y)

    # store identity matrix
    diag = xp.eye(n)

    # find the nearest correlation matrix
    for _ in range(niter):
        # apply Dykstra's correction to current result
        r = y - ds

        # project onto positive semi-definite matrices
        x = cov_clip(r)

        # compute Dykstra's correction
        ds = x - r

        # project onto matrices with unit diagonal
        y = (1 - diag) * x + diag

        # check for convergence
        if xp.all(frob(y - x) <= tol * frob(y)):
            break

    # return result in original shape
    return xp.reshape(y, (*dim, n, n))
215

216

217
def cov_nearest(
    cov: NDArray[np.float64] | Array,
    tol: float | None = None,
    niter: int = 100,
) -> NDArray[np.float64] | Array:
    """
    Covariance matrix from nearest correlation matrix.

    The input is normalised to a correlation matrix by dividing each
    row and column by the square root of the corresponding diagonal
    entry.  The nearest valid correlation matrix is then found with
    :func:`nearcorr` and rescaled, so the diagonal of the input is
    preserved.  Rows and columns with a zero diagonal entry are left
    unnormalised before :func:`nearcorr` and come back as zero.

    Parameters
    ----------
    cov
        A square matrix (or a stack of matrices).
    tol
        Tolerance for convergence, see :func:`nearcorr`.
    niter
        Maximum number of iterations.

    Returns
    -------
        Covariance matrix from nearest correlation matrix.

    Raises
    ------
    ValueError
        If the diagonal of *cov* contains negative values.

    """
    xp = cov.__array_namespace__()

    variances = xp.linalg.diagonal(cov)

    # a negative variance cannot be repaired by rescaling
    if xp.any(variances < 0):
        msg = "negative values on the diagonal"
        raise ValueError(msg)

    # scale[..., i, j] = sqrt(var_j) * sqrt(var_i)
    stddev = xp.sqrt(variances)
    scale = stddev[..., None, :] * stddev[..., :, None]

    # divide by one wherever the scale vanishes to avoid 0/0
    divisor = xp.where(scale > 0, scale, 1.0)
    correlation = cov / divisor

    nearest = nearcorr(correlation, niter=niter, tol=tol)
    return nearest * scale
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc