
GW-JAX-Team / flowMC, build 20065757058 (github, push)

09 Dec 2025 01:48PM UTC, coverage: 91.566% (-0.09% from 91.653%)
Commit by thomasckng: "refactor: replace capsys with caplog for logging in parameter print tests"
1748 of 1909 relevant lines covered (91.57%), 0.92 hits per line

Source file: /src/flowMC/strategy/optimization.py (94.55% covered)
Uncovered in this build: the two bounds ValueError branches and the NaN/Inf warning line.
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray

from flowMC.strategy.base import Strategy
from flowMC.resource.base import Resource
import logging

logger = logging.getLogger(__name__)


class AdamOptimization(Strategy):
    """Optimize a set of chains using Adam optimization. Note that if the posterior can
    go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
            Log-probability function to maximize, called as logpdf(position, data).
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Scale of the multiplicative noise applied to the gradients.
        bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
            Bounds for the optimization; optimized positions are projected into this
            box after every step. A (1, 2) array applies the same bound to every
            dimension; to set different bounds per dimension, provide an array of
            shape (n_dim, 2).
    """

    logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    def __repr__(self):
        return "AdamOptimization"

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
        n_steps: int = 100,
        learning_rate: float = 1e-2,
        noise_level: float = 10,
        bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
    ):
        self.logpdf = logpdf
        self.n_steps = n_steps
        self.learning_rate = learning_rate
        self.noise_level = noise_level
        self.bounds = bounds

        # Validate bounds shape
        if bounds.ndim != 2 or bounds.shape[1] != 2:
            raise ValueError(
                f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}"
            )
        # A (1, 2) bounds array is broadcast to all dimensions. n_dim is not known
        # at construction time, so the per-dimension compatibility check is
        # deferred to optimize().

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, " n_chain n_dim"],
    ]:
        def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float:
            return -self.logpdf(params, data)

        rng_key, optimized_positions, _ = self.optimize(
            rng_key, loss_fn, initial_position, data
        )

        return rng_key, resources, optimized_positions

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ):
        """Optimization kernel. This can be used independently of the __call__ method.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to minimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
            data: dict
                Data to pass to the objective function.

        Returns:
            rng_key: PRNGKeyArray
                Updated random key.
            optimized_positions: Float[Array, " n_chain n_dim"]
                Optimized positions.
            final_log_prob: Float[Array, " n_chain"]
                Final log-probabilities of the optimized positions.
        """
        # Validate bounds shape against n_dim
        n_dim = initial_position.shape[-1]
        if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
            raise ValueError(
                f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
                "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
            )

        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, _step):
            key, params, opt_state = carry

            # Perturb the gradient with multiplicative noise: a single standard
            # normal draw, scaled by noise_level, rescales the whole gradient.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params, data) * (
                1 + jax.random.normal(subkey) * self.noise_level
            )
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            # Project back into the bounding box after each Adam step.
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        logger.info("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))(
            optimized_positions, data
        )

        if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
            logger.warning("Optimization produced infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, final_log_prob
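
Appendix (editor's sketch): a minimal usage example for the strategy above, assuming a toy 2-D Gaussian target. The logpdf, chain count, and bound values are illustrative assumptions, not taken from flowMC's documentation or tests.

import jax
import jax.numpy as jnp

from flowMC.strategy.optimization import AdamOptimization

# Hypothetical target: standard 2-D Gaussian log-density; `data` is unused here.
def logpdf(x, data):
    return -0.5 * jnp.sum(x**2)

strategy = AdamOptimization(
    logpdf=logpdf,
    n_steps=200,
    learning_rate=1e-2,
    noise_level=1.0,
    bounds=jnp.array([[-5.0, 5.0]]),  # shape (1, 2): broadcast to both dimensions
)

rng_key = jax.random.PRNGKey(0)
# 4 chains in 2 dimensions (illustrative shapes).
initial_position = jax.random.normal(jax.random.PRNGKey(1), (4, 2))

# Call the kernel directly with a loss (the negative log-density), as
# __call__ does internally; it returns the key, positions, and log-probs.
rng_key, positions, final_log_prob = strategy.optimize(
    rng_key, lambda x, d: -logpdf(x, d), initial_position, {}
)
print(positions.shape, final_log_prob.shape)  # (4, 2) (4,)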
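
The per-dimension form described in the class docstring, continuing the sketch above (the bound values are again assumptions):

# n_dim = 2: constrain the first coordinate to [0, 1], leave the second free.
bounds = jnp.array([[0.0, 1.0], [-jnp.inf, jnp.inf]])
bounded_strategy = AdamOptimization(logpdf=logpdf, bounds=bounds)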
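
And a quick check of what the projection step does: optax.projections.projection_box clamps each coordinate into its box, with the (1, 2) bounds broadcasting across dimensions (the values here are illustrative):

import optax

params = jnp.array([2.0, -7.0])
lower, upper = jnp.array([-5.0]), jnp.array([5.0])  # the (1, 2) case, split into columns
clipped = optax.projections.projection_box(params, lower, upper)
# clipped == [2.0, -5.0]: only the out-of-box coordinate is moved.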