
kazewong / flowMC / build 13655281414

04 Mar 2025 01:55PM UTC · coverage: 67.835% (remained the same)
Push build via github · committed by kazewong · commit message: "format doc strings"

1 of 1 new or added line in 1 file covered (100.0%)
162 existing lines in 15 files now uncovered
987 of 1455 relevant lines covered (67.84%)
1.36 hits per line

Source File: /src/flowMC/strategy/optimization.py (0.0% of lines in this file covered)
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.nf_model.NF_proposal import NFProposal
from flowMC.strategy.base import Strategy


class optimization_Adam(Strategy):
    """Optimize a set of chains using Adam optimization. Note that if the posterior can
    go to infinity, this optimization scheme is likely to return NaNs.

    Args:
        n_steps: int = 100
            Number of optimization steps.
        learning_rate: float = 1e-2
            Learning rate for the optimization.
        noise_level: float = 10
            Scale (standard deviation) of the multiplicative noise applied to the
            gradients.
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
            Box constraints the parameters are projected into after each update.
    """

    n_steps: int = 100
    learning_rate: float = 1e-2
    noise_level: float = 10
    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

    @property
    def __name__(self):
        return "AdamOptimization"

    def __init__(
        self,
        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
        **kwargs,
    ):
        # Accept any annotated class attribute (n_steps, learning_rate, noise_level)
        # as a keyword override.
        class_keys = list(self.__class__.__annotations__.keys())
        for key, value in kwargs.items():
            if key in class_keys:
                if not key.startswith("__"):
                    setattr(self, key, value)

        self.solver = optax.chain(
            optax.adam(learning_rate=self.learning_rate),
        )

        self.bounds = bounds

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        local_sampler: ProposalBase,
        global_sampler: NFProposal,
        initial_position: Float[Array, " n_chain n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
    ]:
        def loss_fn(params: Float[Array, " n_dim"]) -> Float:
            return -local_sampler.log_pdf(params, data)

        grad_fn = jax.jit(jax.grad(loss_fn))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Perturb the gradient multiplicatively, take an Adam step, then
            # project the parameters back into the box constraints.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, local_sampler, global_sampler, summary

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
    ):
        """Standalone optimization function that takes an objective function and returns
        the optimized positions.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to minimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            # Same noisy-Adam step as in __call__, but without the box projection.
            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:
            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, summary
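Taken on its own, the standalone optimize method only needs a differentiable objective and a batch of starting points. The sketch below is illustrative rather than taken from the repository: the quadratic objective, the chain count, and the hyperparameter values are assumptions, and the import path is inferred from the file path shown above.

import jax
import jax.numpy as jnp

from flowMC.strategy.optimization import optimization_Adam


def objective(x):
    # Negative log-density of a standard Gaussian (up to a constant), so
    # minimizing it drives every chain toward the origin.
    return 0.5 * jnp.sum(x**2)


strategy = optimization_Adam(n_steps=200, learning_rate=1e-2, noise_level=1.0)

rng_key = jax.random.PRNGKey(0)
initial_position = jax.random.normal(rng_key, (4, 2))  # 4 chains in 2 dimensions

rng_key, optimized_positions, summary = strategy.optimize(
    rng_key, objective, initial_position
)
print(summary["final_positions"].shape)  # (4, 2)
print(summary["final_log_prob"])         # objective value for each chain

Because the gradients are perturbed multiplicatively, repeated runs with different keys will land on slightly different positions even for this convex objective.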