kazewong / flowMC

Build 13818329754 · push · github · web-flow

12 Mar 2025 06:01PM UTC coverage: 81.682% (+13.8%) from 67.835%

Merge pull request #196 from kazewong/190-updating-documentation-to-align-with-the-latest-version-of-flowmc

38 of 65 new or added lines in 12 files covered (58.46%).

3 existing lines in 3 files now uncovered.

1039 of 1272 relevant lines covered (81.68%).

1.63 hits per line.

Source File: /src/flowMC/strategy/optimization.py (63.38% of lines covered)

Uncovered in this file: the body of __str__, the infinite/NaN warning branch in __call__, and the body of the legacy optimize method.
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray

from flowMC.strategy.base import Strategy
from flowMC.resource.base import Resource


class AdamOptimization(Strategy):
    """Optimize a set of chains using Adam optimization. Note that if the posterior
    can go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
            Log-probability density function to maximize.
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Standard deviation of the multiplicative Gaussian noise applied to the
            gradients.
        bounds: Float[Array, "n_dim 2"] = [[-inf, inf]]
            Per-dimension box bounds; positions are projected back into the box
            after every update.
    """

    logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    def __str__(self):
        return "AdamOptimization"

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
        n_steps: int = 100,
        learning_rate: float = 1e-2,
        noise_level: float = 10,
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
    ):
        self.logpdf = logpdf
        self.n_steps = n_steps
        self.learning_rate = learning_rate
        self.noise_level = noise_level
        self.bounds = bounds

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        def loss_fn(params: Float[Array, " n_dim"]) -> Float:
            return -self.logpdf(params, data)

        grad_fn = jax.jit(jax.grad(loss_fn))

        def _kernel(carry, data):
            key, params, opt_state = carry

            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))(
            optimized_positions, data
        )

        if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
            print("Warning: Optimization encountered infinite or NaN log-probabilities.")

        return rng_key, resources, optimized_positions

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
    ):
        """Standalone optimization function that takes an objective function and
        returns the optimized positions.

        WARNING: This is an old function that may not be compatible with flowMC 0.4.0.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to optimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization encountered infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, summary
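
Usage note: below is a minimal sketch of calling this strategy on its own, assuming only the __call__ signature shown above. The toy Gaussian logpdf, the chain and dimension counts, the hyperparameter values, the "mean" data key, and the empty resources dict are illustrative choices, not anything prescribed by flowMC.

import jax
import jax.numpy as jnp

from flowMC.strategy.optimization import AdamOptimization

n_chains, n_dim = 8, 2

def logpdf(x, data):
    # Toy Gaussian log-density centered on data["mean"]; any function with the
    # (position, data) signature expected by AdamOptimization would work here.
    return -0.5 * jnp.sum((x - data["mean"]) ** 2)

strategy = AdamOptimization(
    logpdf=logpdf,
    n_steps=200,
    learning_rate=1e-2,
    noise_level=1.0,
    bounds=jnp.array([[-10.0, 10.0]] * n_dim),  # box constraint per dimension
)

rng_key = jax.random.PRNGKey(0)
rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, (n_chains, n_dim)) * 5.0

# __call__ does not read `resources`, so an empty dict suffices for standalone use.
rng_key, resources, optimized = strategy(
    rng_key,
    {},
    initial_position,
    {"mean": jnp.ones(n_dim)},
)
print(optimized.shape)  # (n_chains, n_dim)

Because the strategy returns `resources` unchanged, the same call can later be dropped into a larger flowMC run that does carry shared resources.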