• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 14701266283

28 Apr 2025 05:56AM UTC coverage: 84.656% (-0.3%) from 84.932%
14701266283

Pull #296

github

web-flow
Merge 63f51d864 into 38296893c
Pull Request #296: Incorporate feedback on backend rewrite

36 of 60 new or added lines in 8 files covered. (60.0%)

1 existing line in 1 file now uncovered.

3520 of 4158 relevant lines covered (84.66%)

3.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.0
/src/pymanopt/function.py
1
__all__ = ["Function", "autograd", "jax", "numpy", "pytorch", "tensorflow"]
4✔
2

3
import inspect
4✔
4
from importlib import import_module
4✔
5
from typing import Any, Callable, Optional, Protocol
4✔
6

7
from pymanopt.backends import Backend
4✔
8
from pymanopt.manifolds.manifold import Manifold
4✔
9

10

11
class Function:
    """Couples a user-supplied cost function with a numerical backend.

    The wrapped callable is invoked directly via :meth:`__call__`, while
    gradient and Hessian operators are created lazily by the backend on
    first request and memoized for subsequent calls.
    """

    def __init__(
        self, *, function: Callable, manifold: Manifold, backend: Backend
    ):
        # Keep a reference to the unwrapped callable: the backend derives
        # gradient/Hessian operators from it.
        self._original_function = function
        self._function = function
        self._backend = backend
        # Number of arguments the cost function takes, as dictated by the
        # manifold (e.g. product manifolds supply several points).
        self._num_arguments = manifold.num_values
        # Derived operators are built on demand, not eagerly.
        self._gradient = None
        self._hessian = None

    def __str__(self):
        return f"Function <{self._backend}>"

    @property
    def backend(self):
        # Read-only access to the backend instance.
        return self._backend

    def get_gradient_operator(self):
        """Return the (lazily created, cached) gradient operator."""
        if self._gradient is None:
            build = self._backend.generate_gradient_operator
            self._gradient = build(
                self._original_function, self._num_arguments
            )
        return self._gradient

    def get_hessian_operator(self):
        """Return the (lazily created, cached) Hessian operator."""
        if self._hessian is None:
            build = self._backend.generate_hessian_operator
            self._hessian = build(
                self._original_function, self._num_arguments
            )
        return self._hessian

    def __call__(self, *args, **kwargs):
        """Evaluate the wrapped cost function."""
        return self._function(*args, **kwargs)
46

47

48
def _only_one_true(*args):
4✔
49
    return sum(args) == 1
4✔
50

51

52
class _ObjectiveFunctionDecorator(Protocol):
    """Structural type of the decorators built by ``decorator_factory``.

    A conforming callable takes a manifold (and optionally a dtype) and
    returns a decorator that wraps a cost function in a :class:`Function`.
    """

    def __call__(
        self, manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        ...
57

58

59
def decorator_factory(
    module: str, backend_class: str
) -> _ObjectiveFunctionDecorator:
    """Build a backend-specific objective-function decorator.

    Args:
        module: Name of the submodule of ``pymanopt.backends`` that
            provides the backend (e.g. ``"numpy_backend"``).
        backend_class: Name of the backend class inside that module.

    Returns:
        A decorator taking a manifold (and an optional ``dtype``) that
        wraps a cost function into a :class:`Function` bound to the
        requested backend.
    """

    def decorator(
        manifold: Manifold, dtype: Optional[Any] = None
    ) -> Callable[[Callable[..., Any]], Function]:
        def inner(cost: Callable[..., Any]) -> Function:
            spec = inspect.getfullargspec(cost)
            # The cost must take either plain positional arguments or a
            # single *args catch-all -- never both, and no keyword-style
            # parameters.
            signature_ok = (
                _only_one_true(bool(spec.args), bool(spec.varargs))
                and not spec.varkw
                and not spec.kwonlyargs
            )
            if not signature_ok:
                raise TypeError(
                    "Decorated function must only accept positional arguments "
                    "or a variable-length argument like *x"
                )
            # The backend module is imported only here, at decoration time.
            backend_module = import_module(f"pymanopt.backends.{module}")
            backend_type = getattr(backend_module, backend_class)
            # By default use float64, which is fine for a function that
            # only uses autodiff methods (which do not depend on realness).
            if dtype is None:
                backend = backend_type()
            else:
                backend = backend_type(dtype=dtype)
            return Function(function=cost, manifold=manifold, backend=backend)

        return inner

    return decorator
94

95

96
# Public decorator entry points, one per supported autodiff backend.  Each
# backend module is imported only when the decorator is actually applied to
# a cost function (inside ``decorator_factory``'s ``inner``), so an
# uninstalled optional dependency does not break importing this module.
numpy = decorator_factory("numpy_backend", "NumpyBackend")
jax = decorator_factory("jax_backend", "JaxBackend")
pytorch = decorator_factory("pytorch_backend", "PytorchBackend")
autograd = decorator_factory("autograd_backend", "AutogradBackend")
tensorflow = decorator_factory("tensorflow_backend", "TensorflowBackend")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc