
kazewong / flowMC / build 13654930442

04 Mar 2025 01:36PM UTC · coverage: 67.835% (remained the same)
Triggered by a push on github · author: kazewong · commit message: "black formatting"

0 of 4 new or added lines in 2 files covered (0.0%).
2 existing lines in 2 files are now uncovered.
987 of 1455 relevant lines covered (67.84%).
1.36 hits per line.

Source file: /src/flowMC/strategy/optimization.py (0.0% covered)
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.nf_model.NF_proposal import NFProposal
from flowMC.strategy.base import Strategy


class optimization_Adam(Strategy):
    """
    Optimize a set of chains using Adam optimization.
    Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Standard deviation of the multiplicative Gaussian noise applied to the gradients.
    """

    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    @property
    def __name__(self):
        return "AdamOptimization"

    def __init__(
        self,
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
        **kwargs,
    ):
        # Accept any annotated class attribute (n_steps, learning_rate, noise_level, bounds) as a keyword override.
        class_keys = list(self.__class__.__annotations__.keys())
        for key, value in kwargs.items():
            if key in class_keys:
                if not key.startswith("__"):
                    setattr(self, key, value)

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

        self.bounds = bounds

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        local_sampler: ProposalBase,
        global_sampler: NFProposal,
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
    ]:
        def loss_fn(params: Float[Array, " n_dim"]) -> Float:
            return -local_sampler.log_pdf(params, data)

        grad_fn = jax.jit(jax.grad(loss_fn))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Perturb the gradient with multiplicative Gaussian noise before the Adam update.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            # Project the updated parameters back into the allowed box.
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            # Run n_steps Adam updates for a single chain.
            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        # Optimize all chains in parallel.
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, local_sampler, global_sampler, summary

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
    ):
        """
        Standalone optimization function that takes an objective function and returns the optimized positions.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to optimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, summary
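
For reference, below is a minimal usage sketch of the standalone optimize entry point defined above. It is a sketch under stated assumptions: the import path flowMC.strategy.optimization mirrors the file path shown on this page, while the neg_log_gaussian objective, the chain count, and the hyperparameter values are illustrative choices of this example, not part of flowMC itself.

import jax
import jax.numpy as jnp

from flowMC.strategy.optimization import optimization_Adam


def neg_log_gaussian(x):
    # Illustrative objective: negative log-density of a standard Gaussian
    # (up to a constant), so minimizing it drives each chain toward the origin.
    return 0.5 * jnp.sum(x**2)


n_chains, n_dim = 4, 2
rng_key = jax.random.PRNGKey(0)
rng_key, init_key = jax.random.split(rng_key)
initial_position = 5.0 * jax.random.normal(init_key, (n_chains, n_dim))

# noise_level=0.0 turns off the multiplicative gradient noise for a deterministic run.
strategy = optimization_Adam(n_steps=200, learning_rate=1e-2, noise_level=0.0)
rng_key, optimized_positions, summary = strategy.optimize(
    rng_key, neg_log_gaussian, initial_position
)

print(summary["final_positions"])  # positions after 200 Adam steps per chain
print(summary["final_log_prob"])   # objective values at the final positions

The summary dict mirrors the keys used by __call__; with a generic objective such as this one, the "log_prob" entries simply hold the objective values at the initial and final positions.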