• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 14702850242

28 Apr 2025 07:43AM UTC coverage: 84.632% (-0.3%) from 84.932%
14702850242

Pull #296

github

web-flow
Merge 56dc45acc into 38296893c
Pull Request #296: Incorporate feedback on backend rewrite

36 of 60 new or added lines in 8 files covered. (60.0%)

2 existing lines in 2 files now uncovered.

3519 of 4158 relevant lines covered (84.63%)

3.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.04
/src/pymanopt/core/problem.py
1
"""The Pymanopt problem class."""
2

3
import functools
4✔
4
from typing import Callable, Optional
4✔
5

6
from pymanopt.function import Function
4✔
7
from pymanopt.manifolds.manifold import Manifold
4✔
8

9

10
class Problem:
    """Problem class to define a Riemannian optimization problem.

    Args:
        manifold: Manifold to optimize over.
        cost: A callable decorated with a decorator from
            :mod:`pymanopt.function` which takes a point on a manifold and
            returns a real scalar.
            If any decorator other than :func:`pymanopt.function.numpy` is
            used, the gradient and Hessian functions are generated
            automatically if needed and no ``{euclidean,riemannian}_gradient``
            or ``{euclidean,riemannian}_hessian`` arguments are provided.
        euclidean_gradient: The Euclidean gradient, i.e., the gradient of the
            cost function in the typical sense in the ambient space.
            The returned value need not belong to the tangent space of
            ``manifold``.
        riemannian_gradient: The Riemannian gradient.
            For embedded submanifolds this is simply the projection of
            ``euclidean_gradient`` on the tangent space of ``manifold``.
            In most cases this need not be provided and the Riemannian gradient
            is instead computed internally.
            If provided, the function needs to return a vector in the tangent
            space of ``manifold``.
        euclidean_hessian: The Euclidean Hessian, i.e., the directional
            derivative of ``euclidean_gradient`` in the direction of a tangent
            vector.
        riemannian_hessian: The Riemannian Hessian, i.e., the directional
            derivative of ``riemannian_gradient`` in the direction of a tangent
            vector.
            As with ``riemannian_gradient`` this usually need not be provided
            explicitly.
        preconditioner: Optional callable invoked as
            ``preconditioner(point, tangent_vector)``. If omitted, an identity
            preconditioner that returns ``tangent_vector`` unchanged is used.

    Raises:
        TypeError: If any of ``cost``, the gradient/Hessian arguments or
            ``preconditioner`` is given but not callable.
        ValueError: If the manifold only has a dummy backend and ``cost`` is
            not a backend-decorated :class:`Function`; if both the Euclidean
            and the Riemannian variant of the gradient (or of the Hessian) are
            given; or if a provided function's backend is incompatible with
            the manifold's backend.
    """

    def __init__(
        self,
        manifold: Manifold,
        cost: Callable,
        *,
        euclidean_gradient: Optional[Callable] = None,
        riemannian_gradient: Optional[Callable] = None,
        euclidean_hessian: Optional[Callable] = None,
        riemannian_hessian: Optional[Callable] = None,
        preconditioner: Optional[Callable] = None,
    ):
        self.manifold = manifold

        # Reject non-callables up front with a clear message rather than
        # failing later inside a solver. The builtin ``callable`` is used
        # instead of ``isinstance(..., typing.Callable)`` (deprecated alias).
        for function, name in (
            (cost, "cost"),
            (euclidean_gradient, "euclidean_gradient"),
            (euclidean_hessian, "euclidean_hessian"),
            (riemannian_gradient, "riemannian_gradient"),
            (riemannian_hessian, "riemannian_hessian"),
            (preconditioner, "preconditioner"),
        ):
            if function is not None and not callable(function):
                raise TypeError(f"Function {name} must be callable")

        if manifold.has_dummy_backend():
            # The manifold carries no real backend yet, so adopt the backend
            # the cost function was decorated with.
            if isinstance(cost, Function):
                manifold.set_compatible_backend(cost.backend)
            else:
                raise ValueError(
                    "Neither cost nor manifold have a specified backend."
                )
        else:
            cost = self._validate_function_backend(cost, "cost", manifold)

        # Keep the unwrapped Function around: the lazily generated
        # derivative operators are requested from it.
        self._original_cost = cost
        self._cost = self._wrap_function(cost)

        if euclidean_gradient is not None and riemannian_gradient is not None:
            raise ValueError(
                "Only 'euclidean_gradient' or 'riemannian_gradient' should be "
                "provided, not both"
            )
        if euclidean_hessian is not None and riemannian_hessian is not None:
            raise ValueError(
                "Only 'euclidean_hessian' or 'riemannian_hessian' should be "
                "provided, not both"
            )

        if euclidean_gradient is not None:
            euclidean_gradient = self._validate_function_backend(
                euclidean_gradient, "euclidean_gradient", manifold
            )
            euclidean_gradient = self._wrap_gradient_operator(
                euclidean_gradient
            )
        self._euclidean_gradient = euclidean_gradient
        if euclidean_hessian is not None:
            euclidean_hessian = self._validate_function_backend(
                euclidean_hessian, "euclidean_hessian", manifold
            )
            euclidean_hessian = self._wrap_hessian_operator(
                euclidean_hessian, embed_tangent_vectors=True
            )
        self._euclidean_hessian = euclidean_hessian

        if riemannian_gradient is not None:
            riemannian_gradient = self._validate_function_backend(
                riemannian_gradient, "riemannian_gradient", manifold
            )
            riemannian_gradient = self._wrap_gradient_operator(
                riemannian_gradient
            )
        self._riemannian_gradient = riemannian_gradient
        if riemannian_hessian is not None:
            riemannian_hessian = self._validate_function_backend(
                riemannian_hessian, "riemannian_hessian", manifold
            )
            riemannian_hessian = self._wrap_hessian_operator(
                riemannian_hessian
            )
        self._riemannian_hessian = riemannian_hessian

        if preconditioner is None:

            def default_preconditioner(point, tangent_vector):
                # Identity preconditioner: leave the tangent vector unchanged.
                return tangent_vector

            preconditioner = default_preconditioner
        # Assigned exactly once; ``__setattr__`` forbids later overrides.
        self.preconditioner = preconditioner

    def __setattr__(self, key, value):
        """Set an attribute, forbidding reassignment of fixed attributes.

        Raises:
            AttributeError: If ``manifold`` or ``preconditioner`` is assigned
                after it has already been set.
        """
        if hasattr(self, key) and key in ("manifold", "preconditioner"):
            raise AttributeError(f"Cannot override '{key}' attribute")
        super().__setattr__(key, value)

    @staticmethod
    def _validate_function(function, name):
        """Raise ``TypeError`` unless ``function`` is a :class:`Function`."""
        if not isinstance(function, Function):
            raise TypeError(
                f"Function '{name}' must be decorated with a backend decorator."
            )

    @staticmethod
    def _validate_function_backend(
        function: Callable, name: str, manifold: Manifold
    ):
        """Return ``function`` as a :class:`Function` usable with ``manifold``.

        Already-decorated functions are checked for backend compatibility
        and returned as is. Plain callables are wrapped in a
        :class:`Function` bound to the manifold's backend.

        Raises:
            ValueError: If a decorated function's backend is not compatible
                with the manifold's backend.
        """
        if isinstance(function, Function):
            if not manifold.is_backend_compatible(function.backend):
                raise ValueError(
                    f"Function '{name}' has a backend {function.backend} "
                    "which is not compatible with the manifold's backend"
                    f" {manifold.backend}."
                )
            return function
        else:
            return Function(
                function=function, manifold=manifold, backend=manifold.backend
            )

    def _flatten_arguments(self, arguments, signature):
        """Flatten grouped ``arguments`` according to ``signature``.

        ``signature`` holds one group size per entry of ``arguments``. A group
        of size 1 contributes its single (non-list/tuple) value; larger groups
        must be sequences of exactly that length and are spliced in.

        Raises:
            ValueError: On a count mismatch between ``arguments`` and
                ``signature`` or within a group.
            TypeError: If a size-1 group is given as a list or tuple.
        """
        if len(arguments) != len(signature):
            raise ValueError("Arguments do not match function signature")

        flattened_arguments = []
        # Lengths were checked above, so zip pairs every group with its size.
        for group_size, argument in zip(signature, arguments):
            if group_size == 1:
                if isinstance(argument, (list, tuple)):
                    raise TypeError(
                        "Expected a single value, but got "
                        f"{type(argument).__name__}"
                    )
                flattened_arguments.append(argument)
            else:
                if len(argument) != group_size:
                    raise ValueError(
                        f"Expected {group_size} values, but got {len(argument)}"
                    )
                flattened_arguments.extend(argument)
        return flattened_arguments

    def _group_return_values(self, function, signature):
        """Decorator to group return values according to a given signature.

        Wraps a function inside another function which groups the return
        values of ``function`` according to the group sizes delineated by
        ``signature``. Groups of size 1 are unwrapped to the bare value;
        larger groups are returned as slices of the flat return sequence.

        Raises:
            ValueError: If ``signature`` contains non-integers, or (when the
                wrapper is called) if ``function`` does not return a list or
                tuple of ``sum(signature)`` values.
        """
        if not all(isinstance(group, int) for group in signature):
            raise ValueError("All elements of signature must be integers")

        num_return_values = sum(signature)

        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            return_values = function(*args, **kwargs)
            if not isinstance(return_values, (list, tuple)):
                raise ValueError("Function returned an unexpected value")
            if len(return_values) != num_return_values:
                raise ValueError(
                    "Function returned an unexpected number of arguments"
                )
            groups = []
            i = 0
            for group_size in signature:
                if group_size == 1:
                    group = return_values[i]
                else:
                    group = return_values[i : i + group_size]
                groups.append(group)
                i += group_size
            return groups

        return wrapper

    def _wrap_function(self, function):
        """Adapt ``function`` so it accepts a single manifold ``point``.

        The point is unpacked according to ``self.manifold.point_layout``:
        a list/tuple layout flattens grouped components, an integer layout
        of 1 passes the point through, and larger integers splat it.

        Raises:
            TypeError: If the point layout is neither a sequence nor an int.
        """
        point_layout = self.manifold.point_layout
        if isinstance(point_layout, (tuple, list)):

            @functools.wraps(function)
            def wrapper(point):
                return function(*self._flatten_arguments(point, point_layout))

            return wrapper

        if not isinstance(point_layout, int):
            raise TypeError("Point layout must be an integer")

        if point_layout == 1:

            @functools.wraps(function)
            def wrapper(point):
                return function(point)

        else:

            @functools.wraps(function)
            def wrapper(point):
                return function(*point)

        return wrapper

    def _wrap_gradient_operator(self, gradient):
        """Wrap ``gradient`` to take a point and, for grouped layouts,
        regroup its flat return values to match the point layout."""
        wrapped_gradient = self._wrap_function(gradient)
        point_layout = self.manifold.point_layout
        if isinstance(point_layout, (list, tuple)):
            return self._group_return_values(wrapped_gradient, point_layout)
        return wrapped_gradient

    def _wrap_hessian_operator(
        self, hessian_operator, *, embed_tangent_vectors=False
    ):
        """Wrap ``hessian_operator`` to take ``(point, vector)``.

        Both arguments are unpacked according to the manifold's point layout,
        mirroring ``_wrap_function``. If ``embed_tangent_vectors`` is true,
        the vector is first passed through ``self.manifold.embedding``.

        Raises:
            TypeError: If the point layout is neither a sequence nor an int.
        """
        point_layout = self.manifold.point_layout
        if isinstance(point_layout, (list, tuple)):

            @functools.wraps(hessian_operator)
            def wrapper(point, vector):
                return hessian_operator(
                    *self._flatten_arguments(point, point_layout),
                    *self._flatten_arguments(vector, point_layout),
                )

            wrapper = self._group_return_values(wrapper, point_layout)

        elif not isinstance(point_layout, int):
            # Same validation as ``_wrap_function`` so both wrappers agree on
            # which layouts are acceptable.
            raise TypeError("Point layout must be an integer")

        elif point_layout == 1:

            @functools.wraps(hessian_operator)
            def wrapper(point, vector):
                return hessian_operator(point, vector)

        else:

            @functools.wraps(hessian_operator)
            def wrapper(point, vector):
                return hessian_operator(*point, *vector)

        if embed_tangent_vectors:

            @functools.wraps(wrapper)
            def hvp(point, vector):
                return wrapper(point, self.manifold.embedding(point, vector))

        else:
            hvp = wrapper
        return hvp

    @property
    def cost(self):
        """The cost function, wrapped to accept a single manifold point."""
        return self._cost

    @property
    def euclidean_gradient(self):
        """The Euclidean gradient; generated from the cost on first access
        if it was not provided."""
        if self._euclidean_gradient is None:
            self._euclidean_gradient = self._wrap_gradient_operator(
                self._original_cost.get_gradient_operator()
            )
        return self._euclidean_gradient

    @property
    def riemannian_gradient(self):
        """The Riemannian gradient; if not provided, built on first access by
        converting the Euclidean gradient via the manifold."""
        if self._riemannian_gradient is None:

            def riemannian_gradient(point):
                return self.manifold.euclidean_to_riemannian_gradient(
                    point, self.euclidean_gradient(point)
                )

            self._riemannian_gradient = riemannian_gradient
        return self._riemannian_gradient

    @property
    def euclidean_hessian(self):
        """The Euclidean Hessian-vector product; generated from the cost on
        first access if it was not provided."""
        if self._euclidean_hessian is None:
            self._euclidean_hessian = self._wrap_hessian_operator(
                self._original_cost.get_hessian_operator(),
                embed_tangent_vectors=True,
            )
        return self._euclidean_hessian

    @property
    def riemannian_hessian(self):
        """The Riemannian Hessian; if not provided, built on first access from
        the Euclidean gradient and Hessian via the manifold."""
        if self._riemannian_hessian is None:

            def riemannian_hessian(point, tangent_vector):
                return self.manifold.euclidean_to_riemannian_hessian(
                    point,
                    self.euclidean_gradient(point),
                    self.euclidean_hessian(point, tangent_vector),
                    tangent_vector,
                )

            self._riemannian_hessian = riemannian_hessian
        return self._riemannian_hessian
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc