• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 15239634412

25 May 2025 04:05PM UTC coverage: 84.586% (-0.3%) from 84.932%
15239634412

Pull #296

github

web-flow
Merge 0dccf0ddf into 38296893c
Pull Request #296: Incorporate feedback on backend rewrite

159 of 188 new or added lines in 23 files covered. (84.57%)

3 existing lines in 3 files now uncovered.

3501 of 4139 relevant lines covered (84.59%)

3.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.62
/src/pymanopt/function.py
1
# Public names: the Function wrapper plus one decorator per supported backend.
__all__ = ["Function", "autograd", "jax", "numpy", "pytorch", "tensorflow"]
4✔
2

3
import inspect
4✔
4
from importlib import import_module
4✔
5
from typing import Any, Callable, Optional, Protocol
4✔
6

7
from pymanopt.backends import Backend
4✔
8
from pymanopt.manifolds.manifold import Manifold
4✔
9

10

11
class Function:
    """A cost function paired with the backend that differentiates it.

    Calling an instance forwards all arguments to the wrapped function.
    Gradient and Hessian operators are requested from the backend the first
    time they are asked for and then cached on the instance.

    Args:
        function: The cost function to wrap.
        manifold: Manifold the cost lives on. Only its ``num_values``
            attribute is read here; it is handed to the backend as the number
            of arguments of ``function``.
        backend: Backend that generates the derivative operators.
    """

    def __init__(
        self, *, function: Callable, manifold: Manifold, backend: Backend
    ):
        # Both names start out bound to the same callable. Derivative
        # operators are always built from ``_original_function`` while
        # ``__call__`` dispatches through ``_function`` (presumably so the
        # latter can be swapped elsewhere -- NOTE(review): confirm).
        self._original_function = self._function = function
        self._num_arguments = manifold.num_values

        # Derivative operators; None until first requested.
        self._gradient = self._hessian = None

        self.backend = backend

    def __str__(self):
        return "Function <{}>".format(self.backend)

    def get_gradient_operator(self):
        """Return the gradient operator, generating it on first use.

        The operator comes from ``backend.generate_gradient_operator`` applied
        to the original function and the stored argument count, and is
        reused on every later call.
        """
        if self._gradient is not None:
            return self._gradient
        self._gradient = self.backend.generate_gradient_operator(
            self._original_function, self._num_arguments
        )
        return self._gradient

    def get_hessian_operator(self):
        """Return the Hessian operator, generating it on first use.

        The operator comes from ``backend.generate_hessian_operator`` applied
        to the original function and the stored argument count, and is
        reused on every later call.
        """
        if self._hessian is not None:
            return self._hessian
        self._hessian = self.backend.generate_hessian_operator(
            self._original_function, self._num_arguments
        )
        return self._hessian

    def __call__(self, *args, **kwargs):
        """Evaluate the wrapped function with the given arguments."""
        target = self._function
        return target(*args, **kwargs)
43

44

45
def _only_one_true(*args):
4✔
46
    return sum(args) == 1
4✔
47

48

49
class _ObjectiveFunctionDecorator(Protocol):
    """Structural type of the decorators returned by ``decorator_factory``.

    Such a decorator is called with a manifold (and optionally a dtype) and
    yields a callable that turns a cost function into a ``Function``.
    """

    def __call__(
        self, manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        """Return a decorator that binds cost functions to ``manifold``."""
54

55

56
def decorator_factory(
    module: str, backend_class: str
) -> _ObjectiveFunctionDecorator:
    """Create a backend-specific decorator for cost functions.

    The backend is resolved lazily. ``pymanopt.backends.<module>`` is
    imported, and ``backend_class`` is looked up on it, only when a cost
    function is actually decorated.

    Args:
        module: Submodule of ``pymanopt.backends`` defining the backend.
        backend_class: Name of the backend class inside that submodule.

    Returns:
        A callable taking ``manifold`` and an optional ``dtype``. It returns a
        decorator that wraps a cost function in a :class:`Function`.
    """

    def decorator(
        manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        def inner(cost: Callable[..., Any]) -> Function:
            spec = inspect.getfullargspec(cost)
            # The cost must take either named positional parameters or a
            # ``*args`` parameter (exactly one of the two), and must declare
            # neither ``**kwargs`` nor keyword-only parameters.
            positional_xor_varargs = _only_one_true(
                bool(spec.args), bool(spec.varargs)
            )
            takes_keywords = bool(spec.varkw) or bool(spec.kwonlyargs)
            if takes_keywords or not positional_xor_varargs:
                raise TypeError(
                    "Decorated function must only accept positional arguments "
                    "or a variable-length argument like *x"
                )

            backend_module = import_module(f"pymanopt.backends.{module}")
            backend_type = getattr(backend_module, backend_class)
            if dtype is None:
                # No dtype requested: fall back to the backend's own default
                # (the original author states float64 -- NOTE(review):
                # confirm in the backend). Per the author, this is fine when
                # the function is only used via autodiff methods, which do
                # not depend on realness.
                backend = backend_type()
            else:
                backend = backend_type(dtype=dtype)

            return Function(function=cost, manifold=manifold, backend=backend)

        return inner

    return decorator
91

92

93
# One decorator per supported backend, listed in the same order as __all__.
# Each one imports its backend module only when a cost function is decorated.
autograd = decorator_factory("autograd_backend", "AutogradBackend")
jax = decorator_factory("jax_backend", "JaxBackend")
numpy = decorator_factory("numpy_backend", "NumpyBackend")
pytorch = decorator_factory("pytorch_backend", "PytorchBackend")
tensorflow = decorator_factory("tensorflow_backend", "TensorflowBackend")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc