• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 14637325149

24 Apr 2025 08:43AM UTC coverage: 84.932% (-3.2%) from 88.102%
14637325149

Pull #295

github

web-flow
Merge c93dc573f into a1f52e740
Pull Request #295: Proposal to merge new integration of backends in PyManopt

1915 of 2267 new or added lines in 31 files covered. (84.47%)

36 existing lines in 11 files now uncovered.

3534 of 4161 relevant lines covered (84.93%)

3.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.23
/src/pymanopt/function.py
1
__all__ = ["Function", "autograd", "jax", "numpy", "pytorch", "tensorflow"]
4✔
2

3
import inspect
4✔
4
from importlib import import_module
4✔
5
from typing import Any, Callable, Optional, Protocol
4✔
6

7
from pymanopt.backends import Backend
4✔
8
from pymanopt.manifolds.manifold import Manifold
4✔
9

10

11
class Function:
    """A cost function bound to a manifold and a computational backend.

    Calls are routed through the version of the callable returned by
    ``backend.prepare_function``. Gradient and Hessian operators are
    generated from the original (unprepared) callable the first time they
    are requested and then reused.
    """

    def __init__(
        self, *, function: Callable, manifold: Manifold, backend: Backend
    ):
        if not callable(function):
            raise TypeError(f"Object {function} is not callable")

        # The untouched callable is what the derivative operators are built
        # from; the prepared one is what ``__call__`` evaluates.
        self._original_function = function
        self._backend = backend
        self._function = backend.prepare_function(function)
        # Argument count handed to the operator generators, taken from the
        # manifold's ``num_values``.
        self._num_arguments = manifold.num_values

        # Derivative operators are created on demand and memoized.
        self._gradient = self._hessian = None

    def __str__(self):
        """Return a short label naming the backend."""
        return f"Function <{self._backend}>"

    @property
    def backend(self):
        """The backend this function was constructed with."""
        return self._backend

    def get_gradient_operator(self):
        """Return the gradient operator, generating it on the first call."""
        if self._gradient is not None:
            return self._gradient
        self._gradient = self._backend.generate_gradient_operator(
            self._original_function, self._num_arguments
        )
        return self._gradient

    def get_hessian_operator(self):
        """Return the Hessian operator, generating it on the first call."""
        if self._hessian is not None:
            return self._hessian
        self._hessian = self._backend.generate_hessian_operator(
            self._original_function, self._num_arguments
        )
        return self._hessian

    def __call__(self, *args, **kwargs):
        """Evaluate the backend-prepared callable with the given arguments."""
        prepared = self._function
        return prepared(*args, **kwargs)
49

50

51
def _only_one_true(*args):
4✔
52
    return sum(args) == 1
4✔
53

54

55
class _ObjectiveFunctionDecorator(Protocol):
    """Structural type of the decorators returned by ``decorator_factory``.

    Calling one with a manifold (and optionally a dtype) yields a decorator
    that converts a cost callable into a :class:`Function`.
    """

    def __call__(
        self, manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        """Bind *manifold* and *dtype*, returning a cost-function decorator."""
60

61

62
def decorator_factory(
    module: str, backend_class: str
) -> _ObjectiveFunctionDecorator:
    """Create a backend-specific decorator for cost functions.

    Args:
        module: Name of the submodule of ``pymanopt.backends`` that defines
            the backend, e.g. ``"numpy_backend"``.
        backend_class: Name of the backend class inside that submodule.

    Returns:
        A callable ``decorator(manifold, dtype=None)``. It returns a decorator
        that turns a cost callable into a :class:`Function`. The backend
        module is imported lazily, only when a cost function is decorated.
    """

    def decorator(
        manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        # Explicit checks instead of ``assert`` so validation still happens
        # under ``python -O``. AssertionError is kept so existing callers
        # that catch it continue to work.
        if not isinstance(manifold, Manifold):
            raise AssertionError(
                f"Expected a Manifold instance, got {type(manifold).__name__}"
            )

        def inner(cost: Callable[..., Any]) -> Function:
            argspec = inspect.getfullargspec(cost)
            # Accept either named positional parameters or a single *args,
            # but not both, and never **kwargs or keyword-only parameters.
            if not (
                _only_one_true(bool(argspec.args), bool(argspec.varargs))
                and not argspec.varkw
                and not argspec.kwonlyargs
            ):
                raise AssertionError(
                    "Decorated function must only accept positional arguments "
                    "or a variable-length argument like *x"
                )
            backend_type = getattr(
                import_module(f"pymanopt.backends.{module}"),
                backend_class,
            )
            # By default use float64, which is fine for a function if it only
            # uses autodiff methods (which do not depend on realness).
            backend = (
                backend_type(dtype=dtype)
                if dtype is not None
                else backend_type()
            )
            return Function(function=cost, manifold=manifold, backend=backend)

        return inner

    return decorator
98

99

100
# Public backend decorators, listed in the same order as ``__all__``.
# Each imports its backend module only when a cost function is decorated.
autograd = decorator_factory(
    module="autograd_backend", backend_class="AutogradBackend"
)
jax = decorator_factory(module="jax_backend", backend_class="JaxBackend")
numpy = decorator_factory(module="numpy_backend", backend_class="NumpyBackend")
pytorch = decorator_factory(
    module="pytorch_backend", backend_class="PytorchBackend"
)
tensorflow = decorator_factory(
    module="tensorflow_backend", backend_class="TensorflowBackend"
)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc