GW-JAX-Team / flowMC / build 24136453959

08 Apr 2026 12:55PM UTC coverage: 91.714% (-0.07%) from 91.784%
Event: push, via github (web-flow)
Merge pull request #73 from GW-JAX-Team/flowMC-dev
"Merge final changes before release"

14 of 16 new or added lines in 5 files covered (87.5%).
3 existing lines in 2 files now uncovered.
2081 of 2269 relevant lines covered (91.71%).
0.92 hits per line.

Source file: /src/flowMC/resource/model/common.py (89.55% covered)

from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Key


class Bijection(eqx.Module):
    """Base class for bijective transformations.

    Subclasses must implement :meth:`forward` and :meth:`inverse`.
    The default :meth:`__call__` delegates to :meth:`forward`.

    This is an abstract template that should not be used directly.
    """

    @abstractmethod
    def __init__(self):
        raise NotImplementedError

    def __call__(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        """Apply the forward transformation.

        Args:
            x (Float[Array, "n_dim"]): Input array.
            condition (Float[Array, "n_condition"]): Conditioning variables.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Transformed output and log-det Jacobian.
        """
        return self.forward(x, condition)

    @abstractmethod
    def forward(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        """Transform from input space to output space.

        Args:
            x (Float[Array, "n_dim"]): Input array.
            condition (Float[Array, "n_condition"]): Conditioning variables.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Transformed output and log-det Jacobian.
        """
        raise NotImplementedError

    @abstractmethod
    def inverse(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        """Transform from output space back to input space.

        Args:
            x (Float[Array, "n_dim"]): Array in the output (transformed) space.
            condition (Float[Array, "n_condition"]): Conditioning variables.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Inverse output and log-det Jacobian.
        """
        raise NotImplementedError


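# Illustrative sketch, not part of flowMC: a minimal concrete Bijection
# showing the forward/inverse contract. `condition` is accepted but unused,
# and the log-determinant is returned per dimension so a coupling layer can
# mask it before summing (mirroring MLPAffine below).
class _ExpBijection(Bijection):
    def __init__(self):
        pass  # no parameters

    def forward(self, x, condition):
        # y_i = exp(x_i); dy_i/dx_i = exp(x_i), so log|det J| per dim is x_i.
        return jnp.exp(x), x

    def inverse(self, x, condition):
        # Only valid for positive x (the image of forward).
        return jnp.log(x), -jnp.log(x)

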
class Distribution(eqx.Module):
    """Base class for probability distributions.

    Subclasses must implement :meth:`log_prob` and :meth:`sample`.
    The default :meth:`__call__` delegates to :meth:`log_prob`.

    This is an abstract template that should not be used directly.
    """

    @abstractmethod
    def __init__(self):
        raise NotImplementedError

    def __call__(self, x: Array, key: Optional[Key] = None) -> Array:
        """Evaluate the log-probability of ``x``.

        Args:
            x (Array): Input sample.
            key (Key, optional): Unused; reserved for subclass compatibility.

        Returns:
            Array: Log-probability of ``x``.
        """
        return self.log_prob(x)  # uncovered in this report

    @abstractmethod
    def log_prob(self, x: Array) -> Array:
        """Evaluate the log-probability of ``x``."""
        raise NotImplementedError

    @abstractmethod
    def sample(
        self, rng_key: Key, n_samples: int
    ) -> Float[Array, "n_samples n_features"]:
        """Draw ``n_samples`` samples from the distribution."""
        raise NotImplementedError


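# Illustrative sketch, not part of flowMC: a minimal concrete Distribution,
# a unit Gaussian over n_dim independent coordinates. The `n_dim` field is
# hypothetical, introduced only for this example.
class _StandardNormal(Distribution):
    n_dim: int

    def __init__(self, n_dim: int):
        self.n_dim = n_dim

    def log_prob(self, x: Array) -> Array:
        return jax.scipy.stats.norm.logpdf(x).sum()

    def sample(self, rng_key: Key, n_samples: int):
        return jax.random.normal(rng_key, (n_samples, self.n_dim))

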
class MLP(eqx.Module):
    r"""Multilayer perceptron.

    Args:
        shape (List[int]): Shape of the MLP. The first element is the input dimension,
            the last element is the output dimension.
        key (Key): Random key.
        scale (Float): Scale of the weight initialization; each hidden layer's
            weights are drawn from a normal with variance ``scale / fan_in``.
        activation (Callable): Activation function inserted after each hidden layer.
        use_bias (bool): Whether to use bias.

    Attributes:
        layers (List): Alternating linear layers and activation callables; the
            final entry is the output linear layer.
    """

    layers: List

    def __init__(
        self,
        shape: List[int],
        key: Key,
        scale: Float = 1e-4,
        activation: Callable = jax.nn.relu,
        use_bias: bool = True,
    ):
        self.layers = []
        for i in range(len(shape) - 2):
            key, subkey1, subkey2 = jax.random.split(key, 3)
            layer = eqx.nn.Linear(
                shape[i], shape[i + 1], key=subkey1, use_bias=use_bias
            )
            # Re-initialize the weights with standard deviation sqrt(scale / fan_in).
            weight = jax.random.normal(subkey2, (shape[i + 1], shape[i])) * jnp.sqrt(
                scale / shape[i]
            )
            layer = eqx.tree_at(lambda layer: layer.weight, layer, weight)
            self.layers.append(layer)
            self.layers.append(activation)
        key, subkey = jax.random.split(key)
        self.layers.append(
            eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias)
        )

    def __call__(self, x: Float[Array, " n_in"]) -> Float[Array, " n_out"]:
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def n_input(self) -> int:
        return self.layers[0].in_features

    @property
    def n_output(self) -> int:
        return self.layers[-1].out_features

    @property
    def dtype(self) -> jnp.dtype:
        return self.layers[0].weight.dtype


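# Usage sketch (illustrative): a 4 -> 32 -> 32 -> 8 network applied to a
# single input vector, then batched over 100 inputs with jax.vmap. The
# layer sizes are arbitrary choices for this example.
def _demo_mlp():
    key = jax.random.PRNGKey(0)
    mlp = MLP([4, 32, 32, 8], key)
    x = jnp.ones(4)
    y = mlp(x)  # shape (8,)
    batch = jnp.ones((100, 4))
    ys = jax.vmap(mlp)(batch)  # shape (100, 8)
    return y, ys

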
class MaskedCouplingLayer(Bijection):
    r"""Masked coupling layer.

    f(x) = (1-m) * b(x; c(m*x; z)) + m*x,
    where b is the inner bijector, m is the mask, and c is the conditioner.

    Args:
        bijector (Bijection): Inner bijector in the masked coupling layer.
        mask (Array): Mask. 0 for the input variables that are transformed,
            1 for the input variables that are not transformed.
    """

    _mask: Float[Array, " n_dim"]
    bijector: Bijection

    @property
    def mask(self) -> Float[Array, " n_dim"]:
        return jax.lax.stop_gradient(self._mask)

    def __init__(self, bijector: Bijection, mask: Float[Array, " n_dim"]):
        self.bijector = bijector
        self._mask = mask

    def forward(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        # The inner bijector is conditioned on the masked input; the external
        # `condition` argument is currently unused.
        y, log_det = self.bijector(x, x * self.mask)  # type: ignore
        y = (1 - self.mask) * y + self.mask * x
        log_det = ((1 - self.mask) * log_det).sum()
        return y, log_det

    def inverse(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        y, log_det = self.bijector.inverse(x, x * self.mask)  # type: ignore
        y = (1 - self.mask) * y + self.mask * x
        log_det = ((1 - self.mask) * log_det).sum()
        return y, log_det


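# Composition sketch (illustrative), reusing the _ExpBijection defined above:
# the first two coordinates pass through unchanged (mask = 1), the last two
# are exponentiated, and the log-determinant counts only transformed dims.
# Positive inputs are used so the inverse's log stays in-domain everywhere.
def _demo_masked_coupling():
    mask = jnp.array([1.0, 1.0, 0.0, 0.0])
    layer = MaskedCouplingLayer(_ExpBijection(), mask)
    x = jnp.array([0.5, 0.3, 0.2, 0.1])
    y, log_det = layer.forward(x, condition=jnp.zeros(1))
    # y = [0.5, 0.3, exp(0.2), exp(0.1)], log_det = 0.2 + 0.1
    x_back, _ = layer.inverse(y, condition=jnp.zeros(1))
    return y, log_det, x_back

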
class MLPAffine(Bijection):
    r"""Affine bijection whose scale and shift are produced by MLPs.

    Given conditioning input c, computes
    y = (x + shift(c)) * exp(tanh(scale(c)) * dt),
    where the tanh keeps the log-scale bounded in (-dt, dt).

    Args:
        scale_MLP (MLP): Network producing the (pre-tanh) log-scale.
        shift_MLP (MLP): Network producing the shift.
        dt (Float): Step-size factor applied to both scale and shift.
    """

    scale_MLP: MLP
    shift_MLP: MLP
    dt: Float = 1

    def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: Float = 1):
        self.scale_MLP = scale_MLP
        self.shift_MLP = shift_MLP
        self.dt = dt

    def __call__(
        self, x: Float[Array, " n_dim"], condition_x: Float[Array, " n_cond"]
    ) -> Tuple[Float[Array, " n_dim"], Float]:
        return self.forward(x, condition_x)

    def forward(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        # Note that this returns log_det as an array instead of a scalar,
        # because the masked coupling layer needs the per-dimension
        # log-determinant so it can apply the mask before summing.
        scale = jnp.tanh(self.scale_MLP(condition)) * self.dt
        shift = self.shift_MLP(condition) * self.dt
        log_det = scale
        y = (x + shift) * jnp.exp(scale)
        return y, log_det

    def inverse(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        scale = jnp.tanh(self.scale_MLP(condition)) * self.dt
        shift = self.shift_MLP(condition) * self.dt
        log_det = -scale
        y = x * jnp.exp(-scale) - shift
        return y, log_det


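# Sketch of the intended pairing (illustrative): an MLPAffine driven by the
# masked input inside a MaskedCouplingLayer. Both MLPs map the full n_dim
# vector to n_dim scale/shift values; the hidden width of 16 is arbitrary.
def _demo_mlp_affine_coupling(n_dim: int = 4):
    key = jax.random.PRNGKey(1)
    key, k1, k2 = jax.random.split(key, 3)
    scale_mlp = MLP([n_dim, 16, n_dim], k1)
    shift_mlp = MLP([n_dim, 16, n_dim], k2)
    mask = jnp.array([1.0, 0.0, 1.0, 0.0])
    layer = MaskedCouplingLayer(MLPAffine(scale_mlp, shift_mlp), mask)
    x = jax.random.normal(key, (n_dim,))
    y, log_det = layer.forward(x, condition=jnp.zeros(1))
    x_back, inv_log_det = layer.inverse(y, condition=jnp.zeros(1))
    # x_back matches x up to floating-point error; inv_log_det == -log_det.
    return y, log_det, x_back, inv_log_det

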
class ScalarAffine(Bijection):
    r"""Affine bijection with a fixed scalar scale and shift.

    y = (x + shift) * exp(scale). The condition argument is ignored.

    Args:
        scale (Float): Log-scale of the transformation.
        shift (Float): Shift of the transformation.
    """

    scale: Array
    shift: Array

    def __init__(self, scale: Float, shift: Float):
        self.scale = jnp.array(scale)
        self.shift = jnp.array(shift)

    def __call__(
        self, x: Float[Array, " n_dim"], condition_x: Float[Array, " n_cond"]
    ) -> Tuple[Float[Array, " n_dim"], Float]:
        return self.forward(x, condition_x)

    def forward(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        y = (x + self.shift) * jnp.exp(self.scale)
        log_det = self.scale
        return y, log_det

    def inverse(
        self,
        x: Float[Array, " n_dim"],
        condition: Float[Array, " n_condition"],
    ) -> tuple[Float[Array, " n_dim"], Float]:
        y = x * jnp.exp(-self.scale) - self.shift
        log_det = -self.scale
        return y, log_det


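# Quick round-trip check (illustrative): forward followed by inverse recovers
# the input, and the two log-determinants cancel.
def _demo_scalar_affine():
    bij = ScalarAffine(scale=0.5, shift=1.0)
    x = jnp.array([1.0, 2.0, 3.0])
    cond = jnp.zeros(1)  # ignored by ScalarAffine
    y, log_det = bij.forward(x, cond)  # y = (x + 1) * exp(0.5)
    x_back, neg_log_det = bij.inverse(y, cond)
    assert jnp.allclose(x_back, x)
    assert jnp.allclose(log_det, -neg_log_det)
    return y

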
class Gaussian(Distribution):
    r"""Multivariate Gaussian distribution.

    Args:
        mean (Array): Mean.
        cov (Array): Covariance matrix.
        learnable (bool):
            Whether the mean and covariance matrix are learnable parameters.

    Attributes:
        mean (Array): Mean.
        cov (Array): Covariance matrix.
    """

    _mean: Float[Array, " n_dim"]
    _cov: Float[Array, "n_dim n_dim"]
    learnable: bool = False

    @property
    def mean(self) -> Float[Array, " n_dim"]:
        if self.learnable:
            return self._mean  # uncovered in this report
        else:
            return jax.lax.stop_gradient(self._mean)

    @property
    def cov(self) -> Float[Array, "n_dim n_dim"]:
        if self.learnable:
            return self._cov  # uncovered in this report
        else:
            return jax.lax.stop_gradient(self._cov)

    def __init__(
        self,
        mean: Float[Array, " n_dim"],
        cov: Float[Array, "n_dim n_dim"],
        learnable: bool = False,
    ):
        self._mean = mean
        self._cov = cov
        self.learnable = learnable

    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
        return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov)

    def sample(
        self, rng_key: Key, n_samples: int
    ) -> Float[Array, "n_samples n_features"]:
        return jax.random.multivariate_normal(
            rng_key, self.mean, self.cov, (n_samples,)
        )


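# Usage sketch (illustrative): a 2-D standard normal, sampling and scoring.
def _demo_gaussian():
    dist = Gaussian(mean=jnp.zeros(2), cov=jnp.eye(2))
    samples = dist.sample(jax.random.PRNGKey(2), n_samples=1000)  # (1000, 2)
    log_p = dist.log_prob(samples[0])
    return samples, log_p

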
class Composable(Distribution):
    r"""Distribution built from independent distributions over slices of x.

    Args:
        distributions (list[Distribution]): Component distributions.
        partitions (dict): Maps a name to the (start, end) index range of x
            owned by the corresponding distribution; entries are paired with
            ``distributions`` in order.

    Note: none of the methods below are exercised by the test suite in this
    report (0 hits).
    """

    distributions: list[Distribution]
    partitions: dict[str, tuple[int, int]]

    def __init__(self, distributions: list[Distribution], partitions: dict):
        self.distributions = distributions
        self.partitions = partitions

    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
        log_prob = 0
        for dist, (_, ranges) in zip(self.distributions, self.partitions.items()):
            log_prob += dist.log_prob(x[ranges[0] : ranges[1]])
        return log_prob

    def sample(
        self, rng_key: Key, n_samples: int
    ) -> Float[Array, "n_samples n_features"]:
        # Returns a dict keyed by partition name, not the annotated array.
        samples = {}
        for dist, (key, _) in zip(self.distributions, self.partitions.items()):
            rng_key, sub_key = jax.random.split(rng_key)
            samples[key] = dist.sample(sub_key, n_samples=n_samples)
        return samples  # type: ignore
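

# Usage sketch (illustrative): a 5-D product of a 2-D and a 3-D Gaussian.
# The partition names "block_a"/"block_b" are arbitrary. Note that `sample`
# returns a dict of per-partition arrays, while `log_prob` scores a single
# flat vector sliced by the partition ranges.
def _demo_composable():
    dist = Composable(
        distributions=[
            Gaussian(jnp.zeros(2), jnp.eye(2)),
            Gaussian(jnp.zeros(3), jnp.eye(3)),
        ],
        partitions={"block_a": (0, 2), "block_b": (2, 5)},
    )
    log_p = dist.log_prob(jnp.zeros(5))  # sum of the two blocks' log-probs
    samples = dist.sample(jax.random.PRNGKey(3), n_samples=10)
    # samples == {"block_a": (10, 2) array, "block_b": (10, 3) array}
    return log_p, samples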