
kazewong / flowMC, build 13884581419

16 Mar 2025 02:45PM UTC coverage: 92.824% (+11.2%) from 81.604%

Build trigger: push (github, web-flow)
Merge pull request #197 from kazewong/193-improve-coverage-of-flowmc

6 of 6 new or added lines in 2 files covered (100.0%).
1 existing line in 1 file now uncovered.
996 of 1073 relevant lines covered (92.82%).
1.86 hits per line.

Source File: /src/flowMC/strategy/optimization.py (97.87% covered)
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray

from flowMC.strategy.base import Strategy
from flowMC.resource.base import Resource


class AdamOptimization(Strategy):
    """Optimize a set of chains using Adam optimization. Note that if the posterior can
    go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Scale of the multiplicative Gaussian noise applied to the gradients.
    """

    logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    def __repr__(self):
        return "AdamOptimization"

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
        n_steps: int = 100,
        learning_rate: float = 1e-2,
        noise_level: float = 10,
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
    ):
        self.logpdf = logpdf
        self.n_steps = n_steps
        self.learning_rate = learning_rate
        self.noise_level = noise_level
        self.bounds = bounds

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        def loss_fn(params: Float[Array, " n_dim"]) -> Float:
            return -self.logpdf(params, data)

        rng_key, optimized_positions = self.optimize(
            rng_key, loss_fn, initial_position, data
        )

        return rng_key, resources, optimized_positions

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ):
        """Optimization kernel. This can be used independently of the __call__ method.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to optimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
            data: dict
                Auxiliary data passed to the final log-probability check.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Perturb the gradient with multiplicative Gaussian noise scaled by noise_level.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            # Project the updated parameters back into the box bounds.
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            # Run n_steps Adam updates for one chain with lax.scan.
            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        # Optimize each chain independently with its own random key.
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))(
            optimized_positions, data
        )

        if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
            # This warning is the one line reported as newly uncovered in this build.
            print("Warning: Optimization reached infinite or NaN log-probabilities.")

        return rng_key, optimized_positions
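
For reference, a minimal usage sketch (not part of the file above): it assumes a toy standard-Gaussian log-density; the (params, data) signature of logpdf, the empty resources and data dicts, and the chain and dimension counts are illustrative choices, not taken from flowMC's tests.

import jax
import jax.numpy as jnp

from flowMC.strategy.optimization import AdamOptimization


# Toy target: a standard Gaussian log-density with the (params, data) signature used above.
def logpdf(x, data):
    return -0.5 * jnp.sum(x**2)


strategy = AdamOptimization(logpdf, n_steps=200, learning_rate=1e-2, noise_level=1.0)

rng_key = jax.random.PRNGKey(0)
initial_position = jax.random.normal(rng_key, (4, 2))  # 4 chains, 2 dimensions

# Full strategy call: returns the advanced key, the untouched resources dict,
# and the optimized positions with shape (4, 2).
rng_key, resources, optimized = strategy(rng_key, {}, initial_position, {})

# The kernel can also be called directly, as the optimize docstring notes;
# the objective is a loss over a single chain's parameters.
rng_key, optimized_too = strategy.optimize(
    rng_key, lambda x: -logpdf(x, {}), initial_position, {}
)

Since the default bounds are (-inf, inf), the box projection is a no-op here; passing a finite bounds array of shape (n_dim, 2) keeps each chain inside that box during the updates.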