• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 14210129523

02 Apr 2025 02:22AM UTC coverage: 93.161% (+0.3%) from 92.825%
14210129523

Pull #208

github

kazewong
Add temperature adaption
Pull Request #208: 202 add parallel tempering strategy

100 of 101 new or added lines in 8 files covered. (99.01%)

5 existing lines in 1 file now uncovered.

1117 of 1199 relevant lines covered (93.16%)

1.86 hits per line

Source File
Press 'n' to go to the next uncovered line, 'b' for the previous one

89.36
/src/flowMC/Sampler.py
1
import jax.numpy as jnp
2✔
2
from jaxtyping import Array, Float, PRNGKeyArray
2✔
3
from typing import Optional
2✔
4

5
from flowMC.strategy.base import Strategy
2✔
6
from flowMC.resource.base import Resource
2✔
7
from flowMC.resource_strategy_bundles import ResourceStrategyBundle
2✔
8

9

10
class Sampler:
    """Top level API that the users primarily interact with.

    A ``Sampler`` holds a set of resources and named strategies. Calling
    :meth:`sample` runs the strategies listed in ``strategy_order`` one after
    another, threading the RNG key, the resources and the latest positions
    through each of them.

    Args:
        n_dim (int): Dimension of the parameter space.
        n_chains (int): Number of chains to sample.
        rng_key (PRNGKeyArray): Jax PRNGKey.
        resources (dict[str, Resource], optional): Resources to be used by the
            sampler. Only used when ``strategies`` is also given.
        strategies (dict[str, Strategy], optional): Strategies to be used by the
            sampler, keyed by name. Only used when ``resources`` is also given.
        strategy_order (list[str], optional): Names of the strategies to run,
            in order. Only read when ``resources`` and ``strategies`` are both
            given; otherwise it is taken from ``resource_strategy_bundles``.
        resource_strategy_bundles (ResourceStrategyBundle, optional): Bundle
            supplying ``resources``, ``strategies`` and ``strategy_order``.
            Used when ``resources`` or ``strategies`` is missing.
        **kwargs: Overrides for class-level hyperparameters such as
            ``verbose`` (default False), ``logging`` (default True) and
            ``outdir`` (default "./outdir/"). Unrecognised keys are ignored
            with a ``UserWarning``.

    Raises:
        ValueError: If ``resources`` and ``strategies`` are not both given and
            ``resource_strategy_bundles`` is also missing.
    """

    # Essential parameters
    n_dim: int
    n_chains: int
    rng_key: PRNGKeyArray
    resources: dict[str, Resource]
    strategies: dict[str, Strategy]
    strategy_order: Optional[list[str]]

    # Logging hyperparameters
    verbose: bool = False
    logging: bool = True
    outdir: str = "./outdir/"

    def __init__(
        self,
        n_dim: int,
        n_chains: int,
        rng_key: PRNGKeyArray,
        resources: None | dict[str, Resource] = None,
        strategies: None | dict[str, Strategy] = None,
        strategy_order: None | list[str] = None,
        resource_strategy_bundles: None | ResourceStrategyBundle = None,
        **kwargs,
    ):
        # Copying input into the model
        self.n_dim = n_dim
        self.n_chains = n_chains
        self.rng_key = rng_key

        if resources is not None and strategies is not None:
            print(
                "Resources and strategies provided. Ignoring resource strategy bundles."
            )
            self.resources = resources
            self.strategies = strategies
            self.strategy_order = strategy_order
        else:
            print(
                "Resources or strategies not provided. Using resource strategy bundles."
            )
            if resource_strategy_bundles is None:
                raise ValueError(
                    "Resource strategy bundles not provided. "
                    "Please provide either resources and strategies or resource strategy bundles."
                )
            self.resources = resource_strategy_bundles.resources
            self.strategies = resource_strategy_bundles.strategies
            self.strategy_order = resource_strategy_bundles.strategy_order

        # Set and override any given hyperparameters
        self._apply_hyperparameters(kwargs)

    def _apply_hyperparameters(self, overrides: dict) -> None:
        """Set instance attributes from ``overrides`` for known hyperparameters.

        Args:
            overrides (dict): Keyword arguments collected by ``__init__``.
        """
        import warnings

        # Only plain class-level data attributes (e.g. verbose, logging,
        # outdir) may be overridden. Methods, descriptors and dunder entries
        # of the class namespace must never be replaced by user kwargs.
        overridable = {
            name
            for name, attr in type(self).__dict__.items()
            if not name.startswith("__")
            and not callable(attr)
            and not isinstance(attr, (staticmethod, classmethod, property))
        }
        for key, value in overrides.items():
            if key in overridable:
                setattr(self, key, value)
            else:
                # stacklevel=3 points the warning at the caller of Sampler(...).
                warnings.warn(
                    f"Ignoring unknown keyword argument '{key}' passed to Sampler.",
                    stacklevel=3,
                )

    def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
        """Run every strategy named in ``strategy_order``, in sequence.

        Each strategy is called as ``strategy(rng_key, resources, last_step,
        data)``. Its returned ``(rng_key, resources, last_step)`` triple is fed
        to the next strategy, and ``self.resources`` is rebound to the returned
        resources after each call.

        Args:
            initial_position (Float[Array, "n_chains n_dim"]): Initial positions
                of the chains. Promoted to at least 2-D via ``jnp.atleast_2d``.
            data (dict): Data to be used by the likelihood functions.

        Raises:
            ValueError: If ``strategy_order`` is not a list/tuple of names, or
                names a strategy absent from ``self.strategies``. All names are
                checked before any strategy runs.
        """
        if not isinstance(self.strategy_order, (list, tuple)):
            raise ValueError(
                "strategy_order must be a list of strategy names, "
                f"got {self.strategy_order!r}."
            )
        # Validate the whole order up front so that a bad name late in the
        # list does not surface only after earlier (possibly expensive)
        # strategies have already run and rebound self.resources.
        for name in self.strategy_order:
            if name not in self.strategies:
                raise ValueError(
                    f"Invalid strategy name '{name}' provided. "
                    f"Available strategies are: {list(self.strategies.keys())}."
                )

        last_step = jnp.atleast_2d(initial_position)  # type: ignore
        # NOTE(review): self.rng_key is read but never updated here, so
        # repeated calls to sample() start from the same key — confirm that
        # this is intended.
        rng_key = self.rng_key
        for name in self.strategy_order:
            rng_key, self.resources, last_step = self.strategies[name](
                rng_key, self.resources, last_step, data
            )

    # TODO: Implement quick access and summary functions that operates on buffer

    def serialize(self):
        """Serialize the sampler object."""
        raise NotImplementedError

    def deserialize(self):
        """Deserialize the sampler object."""
        raise NotImplementedError
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc