• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 13818329754

12 Mar 2025 06:01PM UTC coverage: 81.682% (+13.8%) from 67.835%
13818329754

push

github

web-flow
Merge pull request #196 from kazewong/190-updating-documentation-to-align-with-the-latest-version-of-flowmc

190 updating documentation to align with the latest version of flowmc

38 of 65 new or added lines in 12 files covered. (58.46%)

3 existing lines in 3 files now uncovered.

1039 of 1272 relevant lines covered (81.68%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.71
/src/flowMC/Sampler.py
1
import jax.numpy as jnp
2✔
2
from jaxtyping import Array, Float, PRNGKeyArray
2✔
3

4
from flowMC.strategy.base import Strategy
2✔
5
from flowMC.resource.base import Resource
2✔
6
from flowMC.resource_strategy_bundles import ResourceStrategyBundle
2✔
7

8

9
class Sampler:
    """Top level API that the users primarily interact with.

    The sampler holds a dictionary of resources and an ordered list of
    strategies. :meth:`sample` runs the strategies one after another, passing
    the random key, the resources and the latest positions from each strategy
    to the next.

    Args:
        n_dim (int): Dimension of the parameter space.
        n_chains (int): Number of chains to sample.
        rng_key (PRNGKeyArray): Jax PRNGKey.
        resources (dict[str, Resource], optional): Resources to be used by the
            sampler. Only used when ``strategies`` is also given.
        strategies (list[Strategy], optional): List of strategies to be used by
            the sampler. Only used when ``resources`` is also given.
        resource_strategy_bundles (ResourceStrategyBundle, optional): Bundle whose
            ``resources`` and ``strategies`` are used when ``resources`` or
            ``strategies`` is not given.
        **kwargs: Overrides for the class-level hyperparameters listed below.
            Keys that do not name a non-callable class attribute are ignored.

    Hyperparameters:
        verbose (bool): Whether to print out progress. Defaults to False.
        logging (bool): Whether to log the progress. Defaults to True.
        outdir (str): Directory to save the logs. Defaults to "./outdir/".

    Raises:
        ValueError: If ``resources`` and ``strategies`` are not both given and
            ``resource_strategy_bundles`` is None.
    """

    # Essential parameters
    n_dim: int
    n_chains: int
    rng_key: PRNGKeyArray
    resources: dict[str, Resource]
    strategies: list[Strategy]

    # Logging hyperparameters (overridable through ``**kwargs`` in __init__)
    verbose: bool = False
    logging: bool = True
    outdir: str = "./outdir/"

    def __init__(
        self,
        n_dim: int,
        n_chains: int,
        rng_key: PRNGKeyArray,
        resources: None | dict[str, Resource] = None,
        strategies: None | list[Strategy] = None,
        resource_strategy_bundles: None | ResourceStrategyBundle = None,
        **kwargs,
    ):
        # Copying input into the model
        self.n_dim = n_dim
        self.n_chains = n_chains
        self.rng_key = rng_key

        if resources is not None and strategies is not None:
            # Explicit resources/strategies take precedence over any bundle.
            print(
                "Resources and strategies provided. Ignoring resource strategy bundles."
            )
            self.resources = resources
            self.strategies = strategies
        else:
            print(
                "Resources or strategies not provided. Using resource strategy bundles."
            )
            if resource_strategy_bundles is None:
                # Adjacent literals are concatenated verbatim, so the first one
                # needs a trailing space to separate the two sentences.
                raise ValueError(
                    "Resource strategy bundles not provided. "
                    "Please provide either resources and strategies or resource strategy bundles."
                )
            self.resources = resource_strategy_bundles.resources
            self.strategies = resource_strategy_bundles.strategies

        # Set and override any given hyperparameters.
        # Only non-dunder, non-callable class attributes (looked up through
        # the MRO, so subclasses inherit the base hyperparameters) may be
        # overridden. This stops a keyword such as ``sample=...`` from
        # shadowing a method on the instance. Unrecognised keys are ignored.
        cls = type(self)
        for key, value in kwargs.items():
            if key.startswith("__") or not hasattr(cls, key):
                continue
            if callable(getattr(cls, key)):
                continue
            setattr(self, key, value)

    def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
        """Run every strategy in order, starting from ``initial_position``.

        Each strategy is called as ``strategy(rng_key, resources, last_step, data)``
        and must return a ``(rng_key, resources, last_step)`` triple. That triple
        feeds the next strategy, and ``self.resources`` is reassigned after every
        strategy.

        Args:
            initial_position (Float[Array, "n_chains n_dim"]): Initial positions
                of the chains; promoted to at least 2-D with ``jnp.atleast_2d``.
            data (dict): Data to be used by the likelihood functions.
        """
        initial_position = jnp.atleast_2d(initial_position)  # type: ignore
        # NOTE(review): the key returned by the last strategy is not stored back
        # into ``self.rng_key``, so repeated calls start from the same key —
        # confirm this is intended.
        rng_key = self.rng_key
        last_step = initial_position
        for strategy in self.strategies:
            (
                rng_key,
                self.resources,
                last_step,
            ) = strategy(rng_key, self.resources, last_step, data)

    # TODO: Implement quick access and summary functions that operate on buffer

    def serialize(self):
        """Serialize the sampler object.

        Placeholder: not implemented yet; currently does nothing and returns None.
        """
        pass

    def deserialize(self):
        """Deserialize the sampler object.

        Placeholder: not implemented yet; currently does nothing and returns None.
        """
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc