kazewong / flowMC / build 13754639966

10 Mar 2025 01:10AM UTC coverage: 73.097% (+0.2%) from 72.881%

Triggered by a push on github, committed by kazewong:
"Update optimization strategy with the newest api"

0 of 11 new or added lines in 1 file covered. (0.0%)

2 existing lines in 2 files now uncovered.

989 of 1353 relevant lines covered (73.1%)

1.46 hits per line

Source File

/src/flowMC/strategy/optimization.py (0.0% covered)
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.nf_model.NF_proposal import NFProposal
from flowMC.strategy.base import Strategy
from flowMC.resource.base import Resource


class optimization_Adam(Strategy):
    """Optimize a set of chains using Adam optimization. Note that if the posterior
    can go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
            Log-probability function to maximize (its negative is minimized).
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Scale of the multiplicative noise applied to the gradients.
        bounds: Float[Array, "n_dim 2"] = [[-inf, inf]]
            Box bounds into which each position is projected after every update.
    """

    logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    def __str__(self):
        return "AdamOptimization"

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
        n_steps: int = 100,
        learning_rate: float = 1e-2,
        noise_level: float = 10,
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
    ):
        self.logpdf = logpdf
        self.n_steps = n_steps
        self.learning_rate = learning_rate
        self.noise_level = noise_level
        self.bounds = bounds

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        # Minimize the negative log-probability.
        def loss_fn(params: Float[Array, " n_dim"]) -> Float:
            return -self.logpdf(params, data)

        grad_fn = jax.jit(jax.grad(loss_fn))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Perturb the gradient with multiplicative Gaussian noise, take one
            # Adam step, then project the position back into the box bounds.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            # Run n_steps noisy Adam updates for a single chain.
            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        # Optimize all chains in parallel.
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))(
            optimized_positions, data
        )

        if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, resources, optimized_positions

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
    ):
        """Standalone optimization function that takes an objective function and
        returns the optimized positions.

        WARNING: This is an old function that may not be compatible with flowMC 0.4.0.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to minimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Same noisy Adam step as in __call__, but without the box projection.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        # Note: the "log_prob" entries below hold objective values, not log-densities.
        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, summary
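
For orientation, here is a minimal sketch of how this strategy might be driven through the resource-based __call__ API shown above. The toy Gaussian logpdf, the hyperparameter values, the chain/dimension counts, and the empty resources and data dicts are illustrative assumptions, not something taken from flowMC itself.

    import jax
    import jax.numpy as jnp

    from flowMC.strategy.optimization import optimization_Adam


    def logpdf(x, data):
        # Toy target: standard Gaussian log-density (up to a constant).
        return -0.5 * jnp.sum(x**2)


    strategy = optimization_Adam(
        logpdf=logpdf,
        n_steps=200,
        learning_rate=1e-2,
        noise_level=1.0,
    )

    rng_key = jax.random.PRNGKey(0)
    # 4 chains in 2 dimensions, started from random positions.
    initial_position = jax.random.normal(jax.random.PRNGKey(1), (4, 2))

    # Assumption: this strategy neither reads nor writes resources, so an empty
    # resources dict and an empty data dict are passed purely for illustration.
    rng_key, resources, optimized_positions = strategy(
        rng_key, {}, initial_position, {}
    )
    print(optimized_positions.shape)  # (4, 2)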
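A similar sketch for the legacy optimize entry point, reusing the imports and the strategy object from the previous example. Note that optimize minimizes the objective directly (no negation is applied), and, per its docstring, it may not be compatible with flowMC 0.4.0; the quadratic objective below is invented for illustration, while the summary keys follow the code above.

    def objective(x):
        # Illustrative objective: a quadratic bowl with its minimum at x = 3.
        return jnp.sum((x - 3.0) ** 2)


    rng_key = jax.random.PRNGKey(2)
    initial_position = jnp.zeros((4, 2))

    rng_key, optimized_positions, summary = strategy.optimize(
        rng_key, objective, initial_position
    )
    # Despite the names, these entries hold objective values, not log-densities.
    print(summary["initial_log_prob"], summary["final_log_prob"])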