• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 14160944599

30 Mar 2025 11:18PM UTC coverage: 93.205% (+0.04%) from 93.165%
14160944599

push

github

kazewong
Refactor RQSpline_MALA_Bundle to use a dictionary format for strategies and add strategy order

1 of 1 new or added line in 1 file covered. (100.0%)

6 existing lines in 2 files now uncovered.

1111 of 1192 relevant lines covered (93.2%)

1.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.11
/src/flowMC/Sampler.py
1
import jax.numpy as jnp
2✔
2
from jaxtyping import Array, Float, PRNGKeyArray
2✔
3
from typing import Optional
2✔
4

5
from flowMC.strategy.base import Strategy
2✔
6
from flowMC.resource.base import Resource
2✔
7
from flowMC.resource_strategy_bundles import ResourceStrategyBundle
2✔
8

9

10
class Sampler:
    """Top level API that the users primarily interact with.

    A sampler holds a dictionary of named resources and a dictionary of named
    strategies. :meth:`sample` runs the strategies listed in
    ``strategy_order`` one after another, passing the PRNG key, the resources
    and the latest positions from each strategy to the next.

    Args:
        n_dim (int): Dimension of the parameter space.
        n_chains (int): Number of chains to sample.
        rng_key (PRNGKeyArray): Jax PRNGKey.
        resources (dict[str, Resource], optional): Resources to be used by the
            sampler. Only used when ``strategies`` is also given.
        strategies (dict[str, Strategy], optional): Strategies to be used by
            the sampler, keyed by name. Only used when ``resources`` is also
            given.
        strategy_order (list[str], optional): Names (keys of ``strategies``)
            of the strategies to run, in order. Only used together with
            ``resources`` and ``strategies``.
        resource_strategy_bundles (ResourceStrategyBundle, optional): Bundle
            supplying ``resources``, ``strategies`` and ``strategy_order``.
            Used when ``resources`` or ``strategies`` is not given.
        **kwargs: Overrides for hyperparameters defined on the class
            (see below). Unrecognised keys are ignored with a warning.

    Keyword Args:
        verbose (bool): Whether to print out progress. Defaults to False.
        logging (bool): Whether to log the progress. Defaults to True.
        outdir (str): Directory to save the logs. Defaults to "./outdir/".

    Raises:
        ValueError: If neither (``resources`` and ``strategies``) nor
            ``resource_strategy_bundles`` is provided.
    """

    # Essential parameters
    n_dim: int
    n_chains: int
    rng_key: PRNGKeyArray
    resources: dict[str, Resource]
    strategies: dict[str, Strategy]
    strategy_order: Optional[list[str]]

    # Logging hyperparameters (overridable through ``**kwargs``)
    verbose: bool = False
    logging: bool = True
    outdir: str = "./outdir/"

    def __init__(
        self,
        n_dim: int,
        n_chains: int,
        rng_key: PRNGKeyArray,
        resources: None | dict[str, Resource] = None,
        strategies: None | dict[str, Strategy] = None,
        strategy_order: None | list[str] = None,
        resource_strategy_bundles: None | ResourceStrategyBundle = None,
        **kwargs,
    ):
        # Copying input into the model
        self.n_dim = n_dim
        self.n_chains = n_chains
        self.rng_key = rng_key

        if resources is not None and strategies is not None:
            # Explicit resources/strategies take precedence over a bundle;
            # only mention the bundle if one was actually passed.
            if resource_strategy_bundles is not None:
                print(
                    "Resources and strategies provided. Ignoring resource strategy bundles."
                )
            self.resources = resources
            self.strategies = strategies
            self.strategy_order = strategy_order
        else:
            if resource_strategy_bundles is None:
                raise ValueError(
                    "Resource strategy bundles not provided. "
                    "Please provide either resources and strategies or resource strategy bundles."
                )
            print(
                "Resources or strategies not provided. Using resource strategy bundles."
            )
            self.resources = resource_strategy_bundles.resources
            self.strategies = resource_strategy_bundles.strategies
            self.strategy_order = resource_strategy_bundles.strategy_order

        # Set and override any given hyperparameters
        self._set_hyperparameters(kwargs)

    def _set_hyperparameters(self, overrides: dict) -> None:
        """Apply keyword overrides to class-level hyperparameters.

        A key is accepted when it names a public, non-callable attribute that
        exists on this class or any of its bases (so subclasses inherit
        ``verbose``/``logging``/``outdir`` and may add their own). Methods and
        underscore-prefixed names are never overwritten. Any other key is
        reported with a ``UserWarning`` rather than being silently dropped.

        Args:
            overrides (dict): Keyword arguments collected by ``__init__``.
        """
        import warnings

        cls = type(self)
        for key, value in overrides.items():
            is_hyperparameter = (
                not key.startswith("_")
                and hasattr(cls, key)
                and not callable(getattr(cls, key))
            )
            if is_hyperparameter:
                setattr(self, key, value)
            else:
                # stacklevel=3 points the warning at the caller of __init__.
                warnings.warn(
                    f"Ignoring unknown keyword argument {key!r} passed to {cls.__name__}.",
                    stacklevel=3,
                )

    def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
        """Run the strategies named in ``strategy_order``, in sequence.

        Each strategy is called as ``strategy(rng_key, resources, last_step,
        data)`` and its returned ``(rng_key, resources, last_step)`` triple is
        fed to the next strategy. ``self.resources`` and ``self.rng_key`` are
        updated with the final values.

        Args:
            initial_position (Float[Array, "n_chains n_dim"]): Initial
                positions of the chains; promoted with ``jnp.atleast_2d``.
            data (dict): Data to be used by the likelihood functions.

        Raises:
            ValueError: If ``strategy_order`` is not a list.
            KeyError: If ``strategy_order`` names a strategy that is not in
                ``strategies`` (checked before any strategy runs).
        """
        # Validate up front with real exceptions (``assert`` is stripped
        # under ``python -O``) so no expensive strategy runs before failing.
        if not isinstance(self.strategy_order, list):
            raise ValueError(
                "strategy_order must be a list of strategy names, got "
                f"{type(self.strategy_order).__name__}. Provide strategy_order "
                "together with resources and strategies, or use a resource "
                "strategy bundle."
            )
        missing = [name for name in self.strategy_order if name not in self.strategies]
        if missing:
            raise KeyError(
                f"Strategies {missing} in strategy_order are not defined. "
                f"Available strategies: {list(self.strategies)}"
            )

        rng_key = self.rng_key
        last_step = jnp.atleast_2d(initial_position)  # type: ignore
        for name in self.strategy_order:
            rng_key, self.resources, last_step = self.strategies[name](
                rng_key, self.resources, last_step, data
            )

        # Store the advanced key so that a later call to ``sample`` does not
        # reuse the same randomness (JAX PRNG keys must never be reused).
        self.rng_key = rng_key

    # TODO: Implement quick access and summary functions that operates on buffer

    def serialize(self):
        """Serialize the sampler object (not implemented yet)."""
        raise NotImplementedError

    def deserialize(self):
        """Deserialize the sampler object (not implemented yet)."""
        raise NotImplementedError
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc