• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 14196357157

01 Apr 2025 01:02PM UTC coverage: 93.21% (+0.4%) from 92.825%
14196357157

Pull #208

github

kazewong
update parallel tempering hyperparameters
Pull Request #208: 202 add parallel tempering strategy

97 of 102 new or added lines in 4 files covered. (95.1%)

10 existing lines in 3 files now uncovered.

1112 of 1193 relevant lines covered (93.21%)

1.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.64
/src/flowMC/resource/logPDF.py
1
from dataclasses import dataclass
2✔
2
from typing import Callable, Optional
2✔
3
from flowMC.resource.base import Resource
2✔
4
from jaxtyping import Array, Float, PyTree
2✔
5
import jax
2✔
6
import jax.numpy as jnp
2✔
7

8

9
@dataclass
class Variable:
    """Metadata for a single input variable of a log-pdf function.

    The main purpose of this class is to let users give each variable a name
    and record whether it is continuous or not.

    Attributes:
        name: Identifier of the variable.
        continuous: Whether the variable is continuous.
    """

    name: str  # user-facing label of this variable
    continuous: bool  # True when the variable is continuous
19

20

21
@jax.tree_util.register_pytree_node_class
class LogPDF(Resource):
    """A resource class that holds the log-pdf function.

    The main purpose of this class is to wrap the log-pdf function into the
    unified Resource interface.

    Args:
        log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The log-pdf function
        variables (list[Variable]): The list of variables in the log-pdf function
    """

    log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]]
    variables: list[Variable]

    @property
    def n_dims(self) -> int:
        """Number of dimensions, i.e. the number of registered variables."""
        return len(self.variables)

    def __repr__(self):
        return f"LogPDF with {self.n_dims} dimensions"

    def __init__(
        self,
        log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
        variables: Optional[list[Variable]] = None,
        n_dims: Optional[int] = None,
    ):
        """Wrap a log-pdf function together with its variable metadata.

        Args:
            log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The log-pdf
                function, invoked as ``log_pdf(x, data)``.
            variables (list[Variable], optional): The list of variables in the log-pdf function.
                Defaults to None. n_dims must be provided if this argument is None.
            n_dims (int, optional): The number of dimensions of the log-pdf function. Defaults to
                None. If variables is provided, this argument is ignored; otherwise ``n_dims``
                continuous variables named ``x_0``, ``x_1``, ... are created.

        Raises:
            ValueError: If neither ``variables`` nor ``n_dims`` is provided.
        """
        self.log_pdf = log_pdf
        if variables is not None:
            # An explicit variable list takes precedence; n_dims is ignored.
            self.variables = variables
        elif n_dims is not None:
            self.variables = [Variable(f"x_{i}", True) for i in range(n_dims)]
        else:
            raise ValueError("Either variables or n_dims must be provided")

    def __call__(self, x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]:
        """Evaluate the wrapped log-pdf at ``x`` with the given ``data``."""
        return self.log_pdf(x, data)

    def print_parameters(self):
        """Print the name and continuity flag of every variable."""
        print("LogPDF with variables:")
        for var in self.variables:
            print(var.name, var.continuous)

    def save_resource(self, path):
        raise NotImplementedError(
            f"{type(self).__name__} does not implement save_resource"
        )

    def load_resource(self, path):
        raise NotImplementedError(
            f"{type(self).__name__} does not implement load_resource"
        )

    def tree_flatten(self):
        """Flatten for JAX: no array leaves; everything travels as auxiliary data."""
        # The callable and the variable metadata are static from JAX's point of
        # view, so they are stored in aux_data rather than as children.
        children = ()
        aux_data = (self.log_pdf, self.variables)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the auxiliary data produced by ``tree_flatten``."""
        log_pdf, variables = aux_data
        return cls(log_pdf, variables)
83

84

85
@jax.tree_util.register_pytree_node_class
class TemperedPDF(LogPDF):
    """A log-pdf built from a log-likelihood and a log-prior, with the likelihood tempered.

    Calling an instance evaluates
    ``(1 / data["temperature"]) * log_likelihood(x, data) + log_prior(x, data)``,
    i.e. only the likelihood term is scaled by the inverse temperature.
    The log-likelihood is stored in the inherited ``log_pdf`` attribute.
    """

    log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]]

    def __init__(
        self,
        log_likelihood: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
        log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
        variables: Optional[list[Variable]] = None,
        n_dims: Optional[int] = None,
        n_temps: int = 5,
        max_temp: float = 100,
    ):
        """
        Args:
            log_likelihood (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The
                log-likelihood, the only term scaled by the inverse temperature.
            log_prior (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The
                log-prior, added without tempering.
            variables (list[Variable], optional): The list of variables. Defaults to None.
                n_dims must be provided if this argument is None.
            n_dims (int, optional): The number of dimensions. Ignored if variables is provided.
            n_temps (int, optional): Currently unused by this class; accepted for API
                compatibility. Defaults to 5.
            max_temp (float, optional): Currently unused by this class; accepted for API
                compatibility. Defaults to 100.

        Raises:
            ValueError: If neither ``variables`` nor ``n_dims`` is provided.
        """
        super().__init__(log_likelihood, variables, n_dims)
        self.log_prior = log_prior

    def __call__(self, x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]:
        """Evaluate the tempered log-density at ``x``.

        Args:
            x: The point at which to evaluate.
            data (PyTree): Passed to both the log-likelihood and the log-prior; must support
                ``data["temperature"]``.

        Returns:
            ``(1 / temperature) * log_likelihood(x, data) + log_prior(x, data)``.
        """
        temperature = data["temperature"]
        base_pdf = super().__call__(x, data)
        return (1.0 / temperature) * base_pdf + self.log_prior(x, data)

    def original_log_pdf(self, x, data):
        """Return the untempered log-likelihood ``log_likelihood(x, data)``.

        Note that the log-prior is NOT included; only the temperature scaling is removed.
        """
        return super().__call__(x, data)

    def tree_flatten(self):  # type: ignore
        """Flatten for JAX: no array leaves; both callables and the variables are aux data."""
        # n_temps / max_temp are not stored on the instance, so they are not carried over.
        children = ()
        aux_data = (self.log_pdf, self.log_prior, self.variables)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the auxiliary data produced by ``tree_flatten``."""
        log_likelihood, log_prior, variables = aux_data
        return cls(log_likelihood, log_prior, variables, *children)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc