• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 14429832292

13 Apr 2025 01:14PM UTC coverage: 93.314% (-0.03%) from 93.339%
14429832292

push

github

web-flow
Merge pull request #225 from kazewong/development

flowMC 0.4.3 update

259 of 278 new or added lines in 14 files covered. (93.17%)

1 existing line in 1 file now uncovered.

1270 of 1361 relevant lines covered (93.31%)

1.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.77
/src/flowMC/strategy/parallel_tempering.py
1
from flowMC.resource.base import Resource
2✔
2
from flowMC.resource.local_kernel.base import ProposalBase
2✔
3
from flowMC.resource.buffers import Buffer
2✔
4
from flowMC.resource.states import State
2✔
5
from flowMC.resource.logPDF import TemperedPDF
2✔
6
from flowMC.strategy.base import Strategy
2✔
7
from jaxtyping import Array, Float, PRNGKeyArray, Int, Bool
2✔
8
import jax
2✔
9
import jax.numpy as jnp
2✔
10
import equinox as eqx
2✔
11

12

13
class ParallelTempering(Strategy):
    """Sample a tempered PDF with one exchange step.

    This is in essence closer to TakeSteps than global tuning.
    Considering the tempered version of the PDF is only there to
    help with convergence, by default the extra information in
    temperatures not equal to 1 is not saved.

    There should be a version of this class that saves the extra
    information in the temperatures not equal to 1, which could be
    used for other purposes such as diagnostics or training.
    """

    n_steps: int
    tempered_logpdf_name: str
    kernel_name: str
    tempered_buffer_names: list[str]
    # Declared here as well because __init__ assigns it.
    state_name: str
    verbose: bool = False

    def __init__(
        self,
        n_steps: int,
        tempered_logpdf_name: str,
        kernel_name: str,
        tempered_buffer_names: list[str],
        state_name: str,
        verbose: bool = False,
    ):
        """Store the configuration of the strategy.

        Args:
            n_steps (int): Number of local-kernel steps per temperature per call.
            tempered_logpdf_name (str): Resource key of the TemperedPDF.
            kernel_name (str): Resource key of the local kernel.
            tempered_buffer_names (list[str]): Resource keys of
                [tempered positions buffer, temperatures buffer], in that order.
            state_name (str): Resource key of the State holding the
                ``"training"`` flag.
            verbose (bool): Print progress / acceptance information.
        """
        self.n_steps = n_steps
        self.tempered_logpdf_name = tempered_logpdf_name
        self.kernel_name = kernel_name
        self.tempered_buffer_names = tempered_buffer_names
        self.verbose = verbose
        self.state_name = state_name

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dims"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dims"],
    ]:
        """Run one parallel-tempering cycle.

        Resources must contain:
            - TemperedPDF (key ``tempered_logpdf_name``)
            - Local kernel (key ``kernel_name``)
            - A buffer holding the tempered positions
              (key ``tempered_buffer_names[0]``)
            - A buffer holding the temperatures (key ``tempered_buffer_names[1]``)
            - A State whose ``data["training"]`` flag gates buffer updates
              (key ``state_name``)

        This strategy has 3 main steps:
        1. Sample from the tempered PDF using the local kernel for n_steps
        2. Exchange the samples between adjacent temperatures
        3. When training, store the tempered positions and adapt the
           temperatures based on the exchange acceptance rate

        Args:
            rng_key (PRNGKeyArray): jax random key.
            resources (dict[str, Resource]): Resources; the tempered positions
                and temperatures buffers are updated in place when training.
            initial_position (Float[Array, "n_chains n_dims"]): Positions of
                the chains occupying the first temperature slot.
            data (dict): Additional data to pass to the logpdf.

        Returns:
            tuple: (new rng_key, resources, positions in the first temperature
            slot after the local steps and exchanges).

        TODO: Add a way to turn off temperature adaptation to maintain
        detailed balance.
        """

        # Only rng_key is advanced here; the subkey is unused, but the split is
        # kept so the random stream stays identical.
        rng_key, _ = jax.random.split(rng_key)
        assert isinstance(kernel := resources[self.kernel_name], ProposalBase)
        assert isinstance(
            tempered_logpdf := resources[self.tempered_logpdf_name], TemperedPDF
        )
        # Holds every temperature slot except the first, i.e. presumably
        # shape (n_chains, n_temps - 1, n_dims) with n_temps = len(temperatures).
        assert isinstance(
            tempered_positions := resources[self.tempered_buffer_names[0]], Buffer
        )
        assert isinstance(
            temperatures := resources[self.tempered_buffer_names[1]], Buffer
        )
        assert isinstance(state := resources[self.state_name], State)

        # Put the incoming chains in front of the tempered ones so that axis 1
        # lines up with the temperatures: (n_chains, n_temps, n_dims).
        initial_position = jnp.concatenate(
            [initial_position[:, None, :], tempered_positions.data],
            axis=1,
        )

        # Take individual steps (vmapped over chains).
        rng_key, subkey = jax.random.split(rng_key)
        subkey = jax.random.split(subkey, initial_position.shape[0])
        positions, log_probs, do_accepts = eqx.filter_jit(
            eqx.filter_vmap(
                jax.tree_util.Partial(self._ensemble_step, kernel),
                in_axes=(0, 0, None, None, None),
            )
        )(subkey, initial_position, tempered_logpdf, temperatures.data, data)

        if self.verbose:
            mean_accs = jnp.mean(do_accepts)
            print("Mean acceptance of individual steps in PT: " + str(mean_accs))

        # Exchange between adjacent temperatures (vmapped over chains).
        rng_key, subkey = jax.random.split(rng_key)
        subkey = jax.random.split(subkey, initial_position.shape[0])
        positions, log_probs, do_accepts = eqx.filter_jit(
            eqx.filter_vmap(self._exchange, in_axes=(0, 0, None, None, None))
        )(subkey, positions, tempered_logpdf, temperatures.data, data)

        if self.verbose:
            mean_accs = jnp.mean(do_accepts)
            print("Mean acceptance of exchange steps in PT: " + str(mean_accs))

        # Update the buffers only while training.
        if state.data["training"]:
            tempered_positions.update_buffer(positions[:, 1:], 0)

            # Adapt the temperatures from the per-pair exchange acceptances.
            temperatures.update_buffer(
                eqx.filter_jit(self._adapt_temperature)(temperatures.data, do_accepts),
                0,
            )

        return rng_key, resources, positions[:, 0]

    def _individual_step_body(
        self,
        kernel: ProposalBase,
        carry: tuple[
            PRNGKeyArray,
            Float[Array, " n_dims"],
            Float[Array, "1"],
            TemperedPDF,
            Float[Array, ""],
            dict,
        ],
        aux,
    ) -> tuple[
        tuple[
            PRNGKeyArray,
            Float[Array, " n_dims"],
            Float[Array, "1"],
            TemperedPDF,
            Float[Array, ""],
            dict,
        ],
        tuple[
            Float[Array, " n_dims"],
            Float[Array, "1"],
            Int[Array, "1"],
        ],
    ]:
        """Take a step using the kernel and the tempered logpdf.

        This should not be called directly but instead used in a
        jax.lax.scan to take multiple steps.

        Args:
            kernel (ProposalBase): The kernel to use.
            carry (tuple): The current state of the chain.
                - key (PRNGKeyArray): jax random key.
                - position (Float[Array, " n_dims"]): Current position of the chain.
                - log_prob (Float[Array, "1"]): Current tempered log probability.
                - logpdf (TemperedPDF): The tempered LogPDF class.
                - temperatures: Temperature of this chain (a scalar when
                  reached through _ensemble_step, which vmaps over temperatures).
                - data (dict): Additional data to pass to the logpdf.
            aux (None): Not used.

        Returns:
            tuple: Updated carry and the result of the kernel step.
                - carry (tuple): Same layout as the input carry, with a new key,
                  position and log_prob.
                - result (tuple): (position, log_prob, do_accept) of this step.
        """
        key, position, log_prob, logpdf, temperatures, data = carry
        key, subkey = jax.random.split(key)
        position, log_prob, do_accept = kernel.kernel(
            subkey,
            position,
            log_prob,
            jax.tree_util.Partial(logpdf.tempered_log_pdf, temperatures),
            data,
        )
        return (key, position, log_prob, logpdf, temperatures, data), (
            position,
            log_prob,
            do_accept,
        )

    def _individual_step(
        self,
        kernel: ProposalBase,
        rng_key: PRNGKeyArray,
        positions: Float[Array, " n_dims"],
        logpdf: TemperedPDF,
        temperatures: Float[Array, ""],
        data: dict,
    ) -> tuple[
        Float[Array, " n_dims"],
        Float[Array, "1"],
        Int[Array, " n_steps"],
    ]:
        """Run ``self.n_steps`` kernel steps for a single chain at one temperature.

        Args:
            kernel (ProposalBase): The kernel to use for proposing new positions.
            rng_key (PRNGKeyArray): jax random key.
            positions (Float[Array, " n_dims"]): Current position of the chain.
            logpdf (TemperedPDF): The tempered log probability density function.
            temperatures: Temperature of this chain (a scalar when called from
                _ensemble_step, which vmaps over the temperature axis).
            data (dict): Additional data to pass to the logpdf.

        Returns:
            tuple:
                - position (Float[Array, " n_dims"]): Position after the last step.
                - log_prob (Float[Array, "1"]): Tempered log probability after
                  the last step.
                - do_accept (Int[Array, " n_steps"]): Acceptance flag of every
                  step, stacked by jax.lax.scan.
        """
        tempered = jax.tree_util.Partial(logpdf.tempered_log_pdf, temperatures)
        log_probs = tempered(positions, data)

        (_, position, log_prob, _, _, _), (_, _, do_accept) = jax.lax.scan(
            jax.tree_util.Partial(self._individual_step_body, kernel),
            (rng_key, positions, log_probs, logpdf, temperatures, data),
            length=self.n_steps,
        )
        return position, log_prob, do_accept

    # Backward-compatible alias for the previously misspelled method name.
    _individal_step = _individual_step

    def _ensemble_step(
        self,
        kernel: ProposalBase,
        rng_key: PRNGKeyArray,
        positions: Float[Array, "n_temps n_dims"],
        logpdf: TemperedPDF,
        temperatures: Float[Array, " n_temps"],
        data: dict,
    ) -> tuple[
        Float[Array, "n_temps n_dims"],
        Float[Array, " n_temps"],
        Int[Array, "n_temps n_steps"],
    ]:
        """Run the local steps for one chain at every temperature.

        Args:
            kernel (ProposalBase): The kernel to use for proposing new positions.
            rng_key (PRNGKeyArray): jax random key.
            positions (Float[Array, "n_temps n_dims"]): Current positions for
                all temperatures.
            logpdf (TemperedPDF): The tempered log probability density function.
            temperatures (Float[Array, " n_temps"]): Array of temperatures.
            data (dict): Additional data to pass to the logpdf.

        Returns:
            tuple:
                - positions (Float[Array, "n_temps n_dims"]): Updated positions.
                - log_probs (Float[Array, " n_temps"]): Tempered log probabilities.
                - do_accept (Int[Array, "n_temps n_steps"]): Per-step acceptance
                  flags for each temperature.
        """

        if self.verbose:
            # Under jit this only prints once, at trace time.
            print("Taking individual steps")
        rng_key = jax.random.split(rng_key, positions.shape[0])

        # Each temperature slot gets its own key, position and (scalar) temperature.
        positions, log_probs, do_accept = jax.vmap(
            self._individual_step, in_axes=(None, 0, 0, None, 0, None)
        )(kernel, rng_key, positions, logpdf, temperatures, data)

        return positions, log_probs, do_accept

    def _exchange_step_body(
        self,
        carry: tuple[
            PRNGKeyArray,
            Float[Array, "n_temps n_dims"],
            Float[Array, " n_temps"],
            int,
            TemperedPDF,
            Float[Array, " n_temps"],
            dict,
        ],
        aux: None,
    ) -> tuple[
        tuple[
            PRNGKeyArray,
            Float[Array, "n_temps n_dims"],
            Float[Array, " n_temps"],
            int,
            TemperedPDF,
            Float[Array, " n_temps"],
            dict,
        ],
        Int[Array, "1"],
    ]:
        """Propose swapping temperature slots ``idx - 1`` and ``idx``.

        Meant to be used inside jax.lax.scan. The carry's ``idx`` is
        decremented by one, so successive iterations walk from the top
        pair down to the pair (0, 1).

        The log acceptance ratio is
        ``(1/T[idx-1] - 1/T[idx]) * (log_probs[idx] - log_probs[idx-1])``.
        NOTE(review): this form is correct when ``log_probs`` holds the
        un-tempered log density; confirm that is what TemperedPDF.__call__ returns.

        Args:
            carry (tuple): (key, positions, log_probs, idx, logpdf,
                temperatures, data).
            aux (None): Not used.

        Returns:
            tuple: Updated carry (positions and log_probs swapped if accepted,
            idx - 1) and the boolean acceptance flag of this swap.
        """

        key, positions, log_probs, idx, logpdf, temperatures, data = carry

        key, subkey = jax.random.split(key)
        ratio = (1.0 / temperatures[idx - 1] - 1.0 / temperatures[idx]) * (
            log_probs[idx] - log_probs[idx - 1]
        )
        log_uniform = jnp.log(jax.random.uniform(subkey))
        do_accept: Bool[Array, " 1"] = log_uniform < ratio

        # Swap the two adjacent rows (positions, then log_probs) when accepted.
        swapped = jnp.flip(
            jax.lax.dynamic_slice_in_dim(positions, idx - 1, 2, axis=0), axis=0
        )
        positions = jax.lax.cond(
            do_accept,
            true_fun=lambda: jax.lax.dynamic_update_slice_in_dim(
                positions, swapped, idx - 1, axis=0
            ),
            false_fun=lambda: positions,
        )
        swapped = jnp.flip(
            jax.lax.dynamic_slice_in_dim(log_probs, idx - 1, 2, axis=0), axis=0
        )
        log_probs = jax.lax.cond(
            do_accept,
            true_fun=lambda: jax.lax.dynamic_update_slice_in_dim(
                log_probs, swapped, idx - 1, axis=0
            ),
            false_fun=lambda: log_probs,
        )
        return (
            key,
            positions,
            log_probs,
            idx - 1,
            logpdf,
            temperatures,
            data,
        ), do_accept

    def _exchange(
        self,
        key: PRNGKeyArray,
        positions: Float[Array, "n_temps n_dims"],
        logpdf: TemperedPDF,
        temperatures: Float[Array, " n_temps"],
        data: dict,
    ) -> tuple[
        Float[Array, "n_temps n_dims"],
        Float[Array, " n_temps"],
        Int[Array, " n_temps-1"],
    ]:
        """Perform one sweep of exchange proposals between adjacent temperatures.

        Args:
            key (PRNGKeyArray): jax random key.
            positions (Float[Array, "n_temps n_dims"]): Current positions for
                all temperatures.
            logpdf (TemperedPDF): The tempered log probability density function.
            temperatures (Float[Array, " n_temps"]): Array of temperatures.
            data (dict): Additional data to pass to the logpdf.

        Returns:
            tuple:
                - positions (Float[Array, "n_temps n_dims"]): Positions after
                  the exchanges.
                - log_probs (Float[Array, " n_temps"]): ``logpdf`` evaluated at
                  those positions.
                - do_accept (Int[Array, " n_temps-1"]): Acceptance flags where
                  entry ``k`` refers to the swap between slots ``k`` and ``k + 1``.
        """

        if self.verbose:
            print("Exchanging walkers")

        n_temps = positions.shape[0]
        log_probs = jax.vmap(logpdf, in_axes=(0, None))(positions, data)
        (_, positions, log_probs, _, _, _, _), do_accept = jax.lax.scan(
            self._exchange_step_body,
            (key, positions, log_probs, n_temps - 1, logpdf, temperatures, data),
            length=n_temps - 1,
        )
        # The scan visits pairs from (n_temps-2, n_temps-1) down to (0, 1), so
        # its stacked flags are in descending pair order. Flip them so that
        # index k matches pair (k, k + 1), as _adapt_temperature expects.
        return positions, log_probs, jnp.flip(do_accept, axis=0)

    def _adapt_temperature(
        self,
        temperatures: Float[Array, " n_temps"],
        do_accept: Int[Array, "n_chains n_temps-1"],
    ) -> Float[Array, " n_temps"]:
        """Adapt the temperatures based on the exchange acceptance rates.

        The first and last temperatures are kept fixed. Every interior
        spacing ``T[k+1] - T[k]`` (k = 0 .. n_temps-3) is multiplied by
        ``exp(c * (A[k] - A[k+1]))``, where ``A[k]`` is the chain-averaged
        acceptance of the swap between slots k and k + 1 and
        ``c = 100 / n_chains``. The interior temperatures are then rebuilt
        as ``T[0]`` plus the cumulative rescaled spacings.

        Args:
            temperatures (Float[Array, " n_temps"]): Current temperatures.
            do_accept (Int[Array, "n_chains n_temps-1"]): Exchange acceptance
                flags, column k referring to the pair (k, k + 1).

        Returns:
            Float[Array, " n_temps"]: Updated temperatures.

        TODO: The adaptation currently lets T[n_temps-2] grow above the
        (fixed) maximum temperature T[n_temps-1]. Need to add a check to
        prevent this.
        """

        if self.verbose:
            print("Adapting temperatures")

        # A[k]: fraction of chains whose swap between slots k and k + 1 succeeded.
        acceptance_rate = jnp.mean(do_accept, axis=0)
        # NOTE(review): 100 / n_chains is an empirical step size.
        damping_factor = (100.0 / do_accept.shape[0]) * (
            acceptance_rate[:-1] - acceptance_rate[1:]
        )
        # Rescale all spacings but the last; their cumulative sum rebuilds
        # T[1] .. T[n_temps-2] on top of the fixed T[0].
        rescaled_spacings = jnp.diff(temperatures)[:-1] * jnp.exp(damping_factor)
        return temperatures.at[1:-1].set(
            temperatures[0] + jnp.cumsum(rescaled_spacings)
        )
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc