HDembinski / zenflow / build 9484327026

12 Jun 2024 02:20PM UTC coverage: 91.585% (-0.4%) from 92.033%

Pull Request #4: bounds argument for ShiftBounds
Merge 2648be1ef into 97e6e7292

154 of 168 new or added lines in 5 files covered (91.67%).
8 existing lines in 2 files now uncovered.
468 of 511 relevant lines covered (91.59%).
0.92 hits per line.

Source File: /src/zenflow/flow.py (75.0% of relevant lines covered)
"""The Flow class which implements a trainable conditional normalizing flow."""

from typing import Union, Optional

from flax.typing import Array

import jax.numpy as jnp
import jax

from .distributions import Distribution, Beta
from .bijectors import Bijector, Chain
from flax import linen as nn

__all__ = ["Flow"]


class Flow(nn.Module):
    """A conditional normalizing flow."""

    bijector: Bijector
    latent: Distribution = Beta()

    @nn.nowrap
    def _normalize_c(self, x: Array, c: Optional[Array]):
        # Represent "no conditions" as an array of shape (N, 0) and
        # promote a 1D condition vector to column shape (N, 1).
        if c is None:
            c = jnp.zeros((x.shape[0], 0))
        elif c.ndim == 1:
            c = c.reshape(-1, 1)
        return c

    def __call__(
        self,
        x: Array,
        c: Optional[Array] = None,
        *,
        train: bool = False,
    ) -> Array:
        """
        Return the log-likelihood of the samples.

        Parameters
        ----------
        x : Array of shape (N, D)
            N samples from a D-dimensional distribution. It is not necessary to
            normalize this distribution or to transform it to look Gaussian,
            but doing so may accelerate convergence.
        c : Array of shape (N, K) or None
            N values of a K-dimensional vector of conditional variables that
            determine the shape of the D-dimensional distribution.
        train : bool, optional (default = False)
            Whether to run in training mode (update BatchNorm statistics, etc.).

        """
        c = self._normalize_c(x, c)
        x, log_det = self.bijector(x, c, train)
        log_prob = self.latent.log_prob(x) + log_det
        log_prob = jnp.nan_to_num(log_prob, nan=-jnp.inf)
        return log_prob

    def sample(
        self,
        conditions_or_size: Union[Array, int],
        *,
        seed: int = 0,
    ) -> Array:
        """
        Return samples from the learned distribution.

        Parameters
        ----------
        conditions_or_size : Array of shape (N, K) or int
            If the distribution depends on a vector of conditional variables,
            pass one condition vector for each random sample to generate. If
            the distribution does not depend on conditional variables, pass the
            number of random samples directly.
        seed : int (default = 0)
            Seed to use for generating samples.

        """
        if isinstance(conditions_or_size, int):
            size = conditions_or_size
            c = jnp.zeros((size, 0))
        else:
            size = conditions_or_size.shape[0]
            c = conditions_or_size
            if c.ndim == 1:
                c = c.reshape(-1, 1)
        x = self.latent.sample(size, jax.random.PRNGKey(seed))
        x = self.bijector.inverse(x, c)
        return x

    def _steps(self, x, c: Optional[Array] = None, *, inverse: bool = False):
        """Return the intermediate result after each bijector of a Chain."""
        if not isinstance(self.bijector, Chain):
            raise ValueError("only for Chain bijector")

        c = self._normalize_c(x, c)

        results = []
        if inverse:
            for bijector in self.bijector[::-1]:
                x = bijector.inverse(x, c)
                results.append(x)
        else:
            for bijector in self.bijector:
                x, _ = bijector(x, c, False)
                results.append(x)
        return results

Coverage: all lines above are hit except the body of _steps, which is uncovered in this build; within it, the forward (non-inverse) loop and the final return were covered in the previous build and are newly uncovered.
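To show how this module is meant to be driven, here is a minimal sketch of the flax linen call pattern, assuming Flow as defined above is in scope. IdentityBijector is a hypothetical stand-in, not part of zenflow; it only mirrors the bijector interface used in this file, where __call__(x, c, train) returns the transformed x together with the per-sample log-determinant and inverse(x, c) undoes the transform. A real flow would be built from the bijectors in zenflow.bijectors, whose constructors are not shown here.

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Array


class IdentityBijector(nn.Module):
    """Hypothetical stand-in: identity map with zero log-determinant."""

    def __call__(self, x: Array, c: Array, train: bool):
        return x, jnp.zeros(x.shape[0])

    def inverse(self, x: Array, c: Array) -> Array:
        return x


flow = Flow(bijector=IdentityBijector())

x = jnp.full((8, 2), 0.5)  # N=8 samples of a D=2 distribution
c = jnp.zeros((8, 3))      # one K=3 condition vector per sample

# Parameters live outside the module instance (standard flax linen pattern).
variables = flow.init(jax.random.PRNGKey(0), x, c)

# __call__ returns one log-likelihood per sample, shape (8,).
log_prob = flow.apply(variables, x, c)

# In training mode, mutable state such as BatchNorm statistics must be
# declared explicitly:
# log_prob, updates = flow.apply(variables, x, c, train=True, mutable=["batch_stats"])

# Draw new samples: one condition vector per requested sample.
samples = flow.apply(variables, c, method=Flow.sample, seed=1)

With no conditional variables at all, flow.apply(variables, 8, method=Flow.sample) works the same way, because _normalize_c and sample represent "no conditions" as an array of shape (N, 0).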