• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 14702850242

28 Apr 2025 07:43AM UTC coverage: 84.632% (-0.3%) from 84.932%
14702850242

Pull #296

github

web-flow
Merge 56dc45acc into 38296893c
Pull Request #296: Incorporate feedback on backend rewrite

36 of 60 new or added lines in 8 files covered. (60.0%)

2 existing lines in 2 files now uncovered.

3519 of 4158 relevant lines covered (84.63%)

3.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.08
/src/pymanopt/backends/numpy_backend.py
1
from numbers import Number
4✔
2
from typing import Any, Literal, Optional, Union
4✔
3

4
import numpy as np
4✔
5
import packaging.version as pv
4✔
6
import scipy
4✔
7
import scipy.linalg
4✔
8

9
from pymanopt.backends.backend import Backend, DTypePrecision, TupleOrList
4✔
10

11

12
def _raise_not_implemented_error(*args, **kwargs):
    """Placeholder for autodiff operator factories the NumPy backend lacks.

    Any positional and keyword arguments are accepted and ignored, so the
    function can be bound as a class attribute and called as a method (the
    instance then simply arrives in ``args``).

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "No autodiff support available for the NumPy backend"
    raise NotImplementedError(message)
16

17

18
class NumpyBackend(Backend):
    """Backend implementing pymanopt's numerics on top of NumPy and SciPy.

    Arrays are plain ``np.ndarray`` objects. Only ``np.float32``,
    ``np.float64``, ``np.complex64`` and ``np.complex128`` are accepted as
    dtypes. Automatic differentiation is not available: the gradient and
    Hessian operator factories raise ``NotImplementedError``.
    """

    ##########################################################################
    # Common attributes, properties and methods
    ##########################################################################
    array_t = np.ndarray  # type: ignore
    # numpy dtypes follow a very complex hierarchy, and there doesn't seem
    # to exist a single super class they all inherit from. In particular:
    # np.float32, np.complex128 and others are classes. np.dtype("float32"),
    # np.dtype("complex128") and others are not.
    _dtype: type

    # Scipy 1.9.0 added support for calling scipy.linalg.expm on stacked
    # matrices. The installed version cannot change at runtime, so the check
    # is done once here instead of on every linalg_expm call.
    _scipy_expm_supports_stacks: bool = pv.parse(
        scipy.__version__
    ) >= pv.parse("1.9.0")

    def __init__(self, dtype: type = np.float64):
        """Create a backend whose arrays use ``dtype``.

        Args:
            dtype: One of the scalar classes ``np.float32``, ``np.float64``,
                ``np.complex64`` or ``np.complex128``.

        Raises:
            ValueError: If ``dtype`` is not one of the supported classes.
        """
        if dtype not in {
            np.float32,
            np.float64,
            np.complex64,
            np.complex128,
        }:
            raise ValueError(f"dtype {dtype} is not supported")
        self._dtype = dtype

    @property
    def dtype(self):
        """The NumPy scalar class used for arrays created by this backend."""
        return self._dtype

    @property
    def dtype_precision(self) -> DTypePrecision:
        """``SINGLE`` for float32/complex64, ``DOUBLE`` otherwise."""
        return (
            DTypePrecision.SINGLE
            if (self.dtype == np.float32 or self.dtype == np.complex64)
            else DTypePrecision.DOUBLE
        )

    @property
    def is_dtype_real(self):
        """Whether the backend dtype is a real floating-point type."""
        return np.issubdtype(self.dtype, np.floating)

    @staticmethod
    def DEFAULT_REAL_DTYPE():
        return np.float64

    @staticmethod
    def DEFAULT_COMPLEX_DTYPE():
        return np.complex128

    def __repr__(self):
        return f"NumpyBackend(dtype={self.dtype})"

    def to_real_backend(self) -> "NumpyBackend":
        """Return a backend with the real dtype of matching precision.

        Returns ``self`` if the dtype is already real.
        """
        if self.is_dtype_real:
            return self
        if self.dtype == np.complex64:
            return NumpyBackend(dtype=np.float32)
        elif self.dtype == np.complex128:
            return NumpyBackend(dtype=np.float64)
        else:
            raise ValueError(f"dtype {self.dtype} is not supported")

    def to_complex_backend(self) -> "NumpyBackend":
        """Return a backend with the complex dtype of matching precision.

        Returns ``self`` if the dtype is already complex.
        """
        if not self.is_dtype_real:
            return self
        if self.dtype == np.float32:
            return NumpyBackend(dtype=np.complex64)
        elif self.dtype == np.float64:
            return NumpyBackend(dtype=np.complex128)
        else:
            raise ValueError(f"dtype {self.dtype} is not supported")

    ##############################################################################
    # Autodiff methods
    ##############################################################################

    # NumPy has no autodiff; both factories raise NotImplementedError.
    generate_gradient_operator = _raise_not_implemented_error
    generate_hessian_operator = _raise_not_implemented_error

    ##############################################################################
    # Numerics functions
    ##############################################################################

    # Unless a docstring says otherwise, each method below forwards directly
    # to the NumPy (or SciPy) function of the same name.

    def abs(self, array: np.ndarray) -> np.ndarray:
        return np.abs(array)

    def all(self, array: np.ndarray) -> bool:
        """Return whether every entry is truthy, as a Python ``bool``."""
        return np.all(array).item()

    def allclose(
        self,
        array_a: np.ndarray,
        array_b: np.ndarray,
        rtol: float = 1e-5,
        atol: float = 1e-8,
    ) -> bool:
        return np.allclose(array_a, array_b, rtol, atol)

    def any(self, array: np.ndarray) -> bool:
        """Return whether any entry is truthy, as a Python ``bool``."""
        return np.any(array).item()

    def arange(
        self,
        start: int,
        stop: Optional[int] = None,
        step: Optional[int] = None,
    ) -> np.ndarray:
        return np.arange(start, stop, step)

    def arccos(self, array: np.ndarray) -> np.ndarray:
        return np.arccos(array)

    def arccosh(self, array: np.ndarray) -> np.ndarray:
        return np.arccosh(array)

    def arctan(self, array: np.ndarray) -> np.ndarray:
        return np.arctan(array)

    def arctanh(self, array: np.ndarray) -> np.ndarray:
        return np.arctanh(array)

    def argmin(self, array: np.ndarray):
        return np.argmin(array)

    def argsort(self, array: np.ndarray):
        return np.argsort(array)

    def array(self, array: Any) -> np.ndarray:  # type: ignore
        """Convert ``array`` to an ndarray with the backend dtype."""
        return np.asarray(array, dtype=self.dtype)

    def assert_allclose(
        self,
        array_a: np.ndarray,
        array_b: np.ndarray,
        rtol: float = 1e-6,
        atol: float = 1e-6,
    ) -> None:
        """Raise ``ValueError`` unless the arrays are elementwise close.

        NaN entries never count as close (``equal_nan=False``).
        """
        if not np.allclose(
            array_a, array_b, rtol=rtol, atol=atol, equal_nan=False
        ):
            raise ValueError(f"Arrays are not close: {array_a} vs {array_b}")

    def assert_equal(
        self,
        array_a: np.ndarray,
        array_b: np.ndarray,
    ) -> None:
        """Raise ``ValueError`` unless the arrays have equal shape and entries."""
        if not np.array_equal(array_a, array_b):
            raise ValueError(f"Arrays are not equal: {array_a} vs {array_b}")

    def concatenate(
        self, arrays: TupleOrList[np.ndarray], axis: int = 0
    ) -> np.ndarray:
        return np.concatenate(arrays, axis)

    def conjugate(self, array: np.ndarray) -> np.ndarray:
        return np.conjugate(array)

    def cos(self, array: np.ndarray) -> np.ndarray:
        return np.cos(array)

    def diag(self, array: np.ndarray) -> np.ndarray:
        return np.diag(array)

    def diagonal(
        self, array: np.ndarray, axis1: int, axis2: int
    ) -> np.ndarray:
        return np.diagonal(array, axis1, axis2)

    def eps(self) -> float:
        """Machine epsilon of the backend dtype, as a Python ``float``."""
        return float(np.finfo(self.dtype).eps)

    def exp(self, array: np.ndarray) -> np.ndarray:
        return np.exp(array)

    def expand_dims(self, array: np.ndarray, axis: int) -> np.ndarray:
        return np.expand_dims(array, axis)

    def eye(self, size: int) -> np.ndarray:
        """Identity matrix of shape ``(size, size)`` in the backend dtype."""
        return np.eye(size, dtype=self.dtype)

    def hstack(self, arrays: TupleOrList[np.ndarray]) -> np.ndarray:
        return np.hstack(arrays)

    def imag(self, array: np.ndarray) -> np.ndarray:
        return np.imag(array)

    def iscomplexobj(self, array: np.ndarray) -> bool:
        return np.iscomplexobj(array)

    def isnan(self, array: np.ndarray) -> np.ndarray:
        return np.isnan(array)

    def isrealobj(self, array: np.ndarray) -> bool:
        return np.isrealobj(array)

    def linalg_cholesky(self, array: np.ndarray) -> np.ndarray:
        return np.linalg.cholesky(array)

    def linalg_det(self, array: np.ndarray) -> np.ndarray:
        return np.linalg.det(array)

    def linalg_eigh(self, array: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        return np.linalg.eigh(array)

    def linalg_eigvalsh(
        self, array_x: np.ndarray, array_y: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Eigenvalues of (stacks of) Hermitian matrices.

        Without ``array_y`` this is ``np.linalg.eigvalsh``. With ``array_y``
        the generalized problem is solved by ``scipy.linalg.eigvalsh``,
        vectorized over any leading (stack) dimensions.
        """
        if array_y is None:
            return np.linalg.eigvalsh(array_x)
        else:
            return np.vectorize(
                scipy.linalg.eigvalsh, signature="(m,m),(m,m)->(m)"
            )(array_x, array_y)

    def linalg_expm(
        self, array: np.ndarray, symmetric: bool = False
    ) -> np.ndarray:
        """Matrix exponential of a matrix or a stack of matrices.

        Args:
            array: Square matrices in the last two axes.
            symmetric: If True, ``array`` is treated as Hermitian and the
                exponential is computed from its eigendecomposition;
                otherwise ``scipy.linalg.expm`` is used.
        """
        if symmetric:
            return self._hermitian_matrix_function(array, np.exp)
        if self._scipy_expm_supports_stacks:
            scipy_expm = scipy.linalg.expm
        else:
            # Older scipy only accepts a single 2-D matrix.
            scipy_expm = np.vectorize(
                scipy.linalg.expm, signature="(m,m)->(m,m)"
            )
        return scipy_expm(array)

    def linalg_inv(self, array: np.ndarray) -> np.ndarray:
        return np.linalg.inv(array)

    def linalg_logm(
        self, array: np.ndarray, positive_definite: bool = False
    ) -> np.ndarray:
        """Matrix logarithm of a matrix or a stack of matrices.

        Args:
            array: Square matrices in the last two axes.
            positive_definite: If True, ``array`` is treated as Hermitian
                positive definite and the logarithm is computed from its
                eigendecomposition; otherwise ``scipy.linalg.logm`` is
                applied to each matrix.
        """
        if positive_definite:
            return self._hermitian_matrix_function(array, np.log)
        return np.vectorize(scipy.linalg.logm, signature="(m,m)->(m,m)")(
            array
        )

    def _hermitian_matrix_function(self, array: np.ndarray, function):
        """Apply scalar ``function`` to Hermitian matrices via ``eigh``.

        Forms ``v @ (function(w)[..., None] * self.conjugate_transpose(v))``
        from ``w, v = np.linalg.eigh(array)``. The real part is returned
        when ``array`` itself is real.
        """
        w, v = np.linalg.eigh(array)
        w = np.expand_dims(function(w), axis=-1)
        result = v @ (w * self.conjugate_transpose(v))
        if np.isrealobj(array):
            return np.real(result)
        return result

    def linalg_matrix_rank(self, array: np.ndarray) -> int:
        return np.linalg.matrix_rank(array)

    def linalg_norm(
        self,
        array: np.ndarray,
        ord: Union[int, Literal["fro"], None] = None,
        axis: Union[int, TupleOrList[int], None] = None,
        keepdims: bool = False,
    ) -> Union[np.ndarray, Number]:
        return np.linalg.norm(array, ord=ord, axis=axis, keepdims=keepdims)

    def linalg_qr(self, array: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """QR decomposition normalized so that ``diag(r)`` is non-negative.

        ``np.linalg.qr`` is unique only up to a phase per column of ``q``
        (and matching row of ``r``). Rescaling by those phases makes the
        result deterministic. Zero diagonal entries are assigned phase 1.
        """
        q, r = np.linalg.qr(array)
        # Sign (real) or unit-modulus phase (complex) of each diagonal entry
        # of r. np.diagonal returns a read-only view, hence the copy.
        s = np.diagonal(r, axis1=-2, axis2=-1).copy()
        s[s == 0] = 1
        s = s / np.abs(s)
        s = np.expand_dims(s, axis=-1)
        # Scale the columns of q by s and the rows of r by conj(s). Since
        # |s| == 1, q @ r is unchanged and diag(r) becomes |diag(r)|.
        q = q * self.transpose(s)
        r = r * np.conjugate(s)
        return q, r

    def linalg_solve(
        self, array_a: np.ndarray, array_b: np.ndarray
    ) -> np.ndarray:
        return np.linalg.solve(array_a, array_b)

    def linalg_solve_continuous_lyapunov(
        self, array_a: np.ndarray, array_q: np.ndarray
    ) -> np.ndarray:
        return scipy.linalg.solve_continuous_lyapunov(array_a, array_q)

    def linalg_svd(
        self,
        array: np.ndarray,
        full_matrices: bool = True,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        return np.linalg.svd(array, full_matrices=full_matrices)

    def linalg_svdvals(self, array: np.ndarray) -> np.ndarray:
        """Singular values only (``compute_uv=False``)."""
        return np.linalg.svd(array, compute_uv=False)

    def log(self, array: np.ndarray) -> np.ndarray:
        return np.log(array)

    def log10(self, array: np.ndarray) -> np.ndarray:
        return np.log10(array)

    def logical_not(self, array: np.ndarray) -> np.ndarray:
        return np.logical_not(array)

    def logspace(self, start: float, stop: float, num: int) -> np.ndarray:
        """``np.logspace`` in the backend dtype."""
        return np.logspace(start, stop, num, dtype=self.dtype)

    def ndim(self, array: np.ndarray) -> int:
        return array.ndim

    def ones(self, shape: TupleOrList[int]) -> np.ndarray:
        """Array of ones in the backend dtype."""
        return np.ones(shape, self.dtype)

    def ones_bool(self, shape: TupleOrList[int]) -> np.ndarray:
        """Boolean array filled with ``True``."""
        return np.ones(shape, bool)

    def prod(self, array: np.ndarray) -> float:
        return np.prod(array)  # type: ignore

    def random_normal(
        self,
        loc: float = 0.0,
        scale: float = 1.0,
        size: Union[int, TupleOrList[int], None] = None,
    ) -> np.ndarray:
        """Gaussian samples in the backend dtype.

        For complex dtypes, the real and imaginary parts are drawn
        independently, each with the given ``loc`` and ``scale``.
        """
        if self.is_dtype_real:
            return np.asarray(
                np.random.normal(loc=loc, scale=scale, size=size),
                dtype=self.dtype,
            )
        else:
            real_dtype = np.finfo(self.dtype).dtype
            return np.asarray(
                np.random.normal(loc=loc, scale=scale, size=size),
                dtype=real_dtype,
            ) + 1j * np.asarray(
                np.random.normal(loc=loc, scale=scale, size=size),
                dtype=real_dtype,
            )

    def random_uniform(
        self, size: Union[int, TupleOrList[int], None] = None
    ) -> np.ndarray:
        """Uniform samples on [0, 1) in the backend dtype.

        For complex dtypes, the real and imaginary parts are drawn
        independently.
        """
        if self.is_dtype_real:
            return np.asarray(np.random.uniform(size=size), dtype=self.dtype)
        else:
            real_dtype = np.finfo(self.dtype).dtype
            return np.asarray(
                np.random.uniform(size=size), dtype=real_dtype
            ) + 1j * np.asarray(np.random.uniform(size=size), dtype=real_dtype)

    def real(self, array: np.ndarray) -> np.ndarray:
        return np.real(array)

    def reshape(
        self, array: np.ndarray, newshape: TupleOrList[int]
    ) -> np.ndarray:
        return np.reshape(array, newshape)

    def sin(self, array: np.ndarray) -> np.ndarray:
        return np.sin(array)

    def sinc(self, array: np.ndarray) -> np.ndarray:
        return np.sinc(array)

    def sort(self, array: np.ndarray, descending: bool = False) -> np.ndarray:
        """Return a sorted copy of ``array`` along its last axis.

        Args:
            array: Input array (at least 1-D).
            descending: If True, sort from largest to smallest. NumPy's sort
                is ascending only, so the ascending result is reversed along
                the sorted axis. NaNs, which ``np.sort`` places last, then
                come first.
        """
        sorted_array = np.sort(array)
        if descending:
            return np.flip(sorted_array, axis=-1)
        return sorted_array

    def sqrt(self, array: np.ndarray) -> np.ndarray:
        return np.sqrt(array)

    def squeeze(self, array: np.ndarray) -> np.ndarray:
        return np.squeeze(array)

    def stack(
        self, arrays: TupleOrList[np.ndarray], axis: int = 0
    ) -> np.ndarray:
        return np.stack(arrays, axis=axis)

    def sum(
        self,
        array: np.ndarray,
        axis: Union[int, TupleOrList[int], None] = None,
        keepdims: bool = False,
    ) -> np.ndarray:
        return np.sum(array, axis=axis, keepdims=keepdims)  # type: ignore

    def tan(self, array: np.ndarray) -> np.ndarray:
        return np.tan(array)

    def tanh(self, array: np.ndarray) -> np.ndarray:
        return np.tanh(array)

    def tensordot(
        self, a: np.ndarray, b: np.ndarray, axes: int = 2
    ) -> np.ndarray:
        return np.tensordot(a, b, axes=axes)

    def tile(
        self, array: np.ndarray, reps: Union[int, TupleOrList[int]]
    ) -> np.ndarray:
        return np.tile(array, reps)

    def trace(self, array: np.ndarray) -> Union[np.ndarray, Number]:
        """Trace over the last two axes.

        A single 2-D matrix yields a Python scalar; a stack of matrices
        yields an array of traces.
        """
        return (
            np.trace(array).item()
            if array.ndim == 2
            else np.trace(array, axis1=-2, axis2=-1)
        )

    def transpose(self, array: np.ndarray) -> np.ndarray:
        """Swap the last two axes (batched transpose, no conjugation).

        Raises:
            numpy.exceptions.AxisError: If ``array`` has fewer than two
                dimensions (an ``IndexError``/``ValueError`` subclass).
        """
        return np.swapaxes(array, -1, -2)

    def triu(self, array: np.ndarray, k: int = 0) -> np.ndarray:
        return np.triu(array, k)

    def vstack(self, arrays: TupleOrList[np.ndarray]) -> np.ndarray:
        return np.vstack(arrays)

    def where(
        self,
        condition: np.ndarray,
        x: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
    ) -> Union[np.ndarray, tuple[np.ndarray, ...]]:
        """Select from ``x``/``y`` by ``condition``, or return nonzero indices.

        Raises:
            ValueError: If exactly one of ``x`` and ``y`` is given.
        """
        if x is None and y is None:
            return np.where(condition)
        elif x is not None and y is not None:
            return np.where(condition, x, y)
        else:
            raise ValueError(
                f"Both x and y have to be specified but are respectively {x} and {y}"
            )

    def zeros(self, shape: TupleOrList[int]) -> np.ndarray:
        """Array of zeros in the backend dtype."""
        return np.zeros(shape, dtype=self.dtype)

    def zeros_bool(self, shape: TupleOrList[int]) -> np.ndarray:
        """Boolean array filled with ``False``."""
        return np.zeros(shape, bool)

    def zeros_like(self, array: np.ndarray) -> np.ndarray:
        """Zeros shaped like ``array`` but in the backend dtype."""
        return np.zeros_like(array, dtype=self.dtype)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc