• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

GW-JAX-Team / flowMC / 20065757058

09 Dec 2025 01:48PM UTC coverage: 91.566% (-0.09%) from 91.653%
20065757058

push

github

thomasckng
refactor: replace capsys with caplog for logging in parameter print tests

1748 of 1909 relevant lines covered (91.57%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.8
/src/flowMC/Sampler.py
1
import jax.numpy as jnp
1✔
2
from jaxtyping import Array, Float, PRNGKeyArray
1✔
3
from typing import Optional
1✔
4
import logging
1✔
5

6
from flowMC.strategy.base import Strategy
1✔
7
from flowMC.resource.base import Resource
1✔
8
from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle
1✔
9

10
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
1✔
11

12

13
class Sampler:
    """Top level API that the users primarily interact with.

    A sampler holds a set of resources and a set of named strategies. Calling
    :meth:`sample` runs the strategies listed in ``strategy_order`` one after
    another, threading the PRNG key, the resources and the last position
    through each of them.

    Args:
        n_dim (int): Dimension of the parameter space.
        n_chains (int): Number of chains to sample.
        rng_key (PRNGKeyArray): Jax PRNGKey.
        resources (dict[str, Resource]): Resources to be used by the sampler.
        strategies (dict[str, Strategy]): Strategies to be used by the sampler.
        strategy_order (list[str]): Names of the strategies, in the order in
            which they are run by :meth:`sample`.
        resource_strategy_bundles (ResourceStrategyBundle): Bundle supplying
            ``resources``, ``strategies`` and ``strategy_order``. Only used
            when ``resources`` or ``strategies`` is not given.
        **kwargs: Overrides for class-level hyperparameters:
            verbose (bool): Whether to print out progress. Defaults to False.
            logging (bool): Whether to log the progress. Defaults to True.
            outdir (str): Directory to save the logs. Defaults to "./outdir/".

    Raises:
        ValueError: If ``resources``/``strategies`` are not both provided and
            no ``resource_strategy_bundles`` is given either.
    """

    # Essential parameters
    n_dim: int
    n_chains: int
    rng_key: PRNGKeyArray
    resources: dict[str, Resource]
    strategies: dict[str, Strategy]
    strategy_order: Optional[list[str]]

    # Logging hyperparameters
    verbose: bool = False
    logging: bool = True
    outdir: str = "./outdir/"

    def __init__(
        self,
        n_dim: int,
        n_chains: int,
        rng_key: PRNGKeyArray,
        resources: Optional[dict[str, Resource]] = None,
        strategies: Optional[dict[str, Strategy]] = None,
        strategy_order: Optional[list[str]] = None,
        resource_strategy_bundles: Optional[ResourceStrategyBundle] = None,
        **kwargs,
    ):
        # Copying input into the model
        self.n_dim = n_dim
        self.n_chains = n_chains
        self.rng_key = rng_key

        if resources is not None and strategies is not None:
            # Explicit resources/strategies take precedence over any bundle.
            logger.info(
                "Resources and strategies provided. Ignoring resource strategy bundles."
            )
            self.resources = resources
            self.strategies = strategies
            self.strategy_order = strategy_order

        else:
            logger.info(
                "Resources or strategies not provided. Using resource strategy bundles."
            )
            if resource_strategy_bundles is None:
                # Trailing space keeps the two sentences from running together.
                raise ValueError(
                    "Resource strategy bundles not provided. "
                    "Please provide either resources and strategies or resource strategy bundles."
                )
            self.resources = resource_strategy_bundles.resources
            self.strategies = resource_strategy_bundles.strategies
            self.strategy_order = resource_strategy_bundles.strategy_order

        # Set and override any given hyperparameters.
        # Walk the MRO instead of looking only at the most-derived class so
        # that subclasses still accept the hyperparameters declared here.
        # Annotation-only names (e.g. ``n_dim``) have no class attribute and
        # are therefore not overridable through kwargs.
        overridable = {
            name
            for klass in type(self).__mro__
            for name in vars(klass)
            if not name.startswith("__")
        }
        for key, value in kwargs.items():
            if key in overridable:
                setattr(self, key, value)
            else:
                # Still ignored, but no longer silently: a mistyped
                # hyperparameter name is otherwise very hard to notice.
                logger.warning("Ignoring unrecognized keyword argument '%s'.", key)

    def sample(
        self, initial_position: Float[Array, "n_chains n_dim"], data: dict
    ) -> None:
        """Run every strategy in ``strategy_order`` starting from a position.

        Each strategy is called as
        ``strategy(rng_key, resources, last_step, data)`` and must return the
        updated ``(rng_key, resources, last_step)`` triple, which is fed to
        the next strategy. The updated resources are stored back on
        ``self.resources``; nothing is returned.

        Args:
            initial_position (Device Array): Initial position. Promoted to at
                least two dimensions before being handed to the first strategy.
            data (dict): Data to be used by the likelihood functions

        Raises:
            AssertionError: If ``strategy_order`` is not a list (e.g. it was
                left as ``None``).
            ValueError: If a name in ``strategy_order`` is not a key of
                ``strategies``. Strategies listed before the invalid name
                will already have run.
        """

        initial_position = jnp.atleast_2d(initial_position)  # type: ignore
        # NOTE(review): the advanced key is kept in a local only and is not
        # written back to ``self.rng_key`` -- confirm this is intended.
        rng_key = self.rng_key
        last_step = initial_position
        # Kept as an assert so the exception type seen by callers is
        # unchanged; it also narrows Optional[list[str]] for type checkers.
        assert isinstance(self.strategy_order, list), (
            "strategy_order must be a list of strategy names, "
            f"got {type(self.strategy_order).__name__}."
        )
        for strategy in self.strategy_order:
            if strategy not in self.strategies:
                raise ValueError(
                    f"Invalid strategy name '{strategy}' provided. "
                    f"Available strategies are: {list(self.strategies.keys())}."
                )
            (
                rng_key,
                self.resources,
                last_step,
            ) = self.strategies[strategy](rng_key, self.resources, last_step, data)

    # TODO: Implement quick access and summary functions that operates on buffer

    def serialize(self):
        """Serialize the sampler object.

        Raises:
            NotImplementedError: Always; serialization is not implemented.
        """
        raise NotImplementedError

    def deserialize(self):
        """Deserialize the sampler object.

        Raises:
            NotImplementedError: Always; deserialization is not implemented.
        """
        raise NotImplementedError
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc