kazewong / flowMC / build 13818329754

12 Mar 2025 06:01PM UTC coverage: 81.682% (+13.8%) from 67.835%

push · github · web-flow

Merge pull request #196 from kazewong/190-updating-documentation-to-align-with-the-latest-version-of-flowmc

38 of 65 new or added lines in 12 files covered. (58.46%)

3 existing lines in 3 files now uncovered.

1039 of 1272 relevant lines covered (81.68%)

1.63 hits per line

Source File: /src/flowMC/utils/PythonFunctionWrap.py (0.0% of lines covered)
import warnings
from functools import wraps
from typing import Any, Callable, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax._src import dtypes
from jax._src.util import safe_zip
from jax.custom_batching import custom_vmap
from jax.experimental import host_callback
from jax.tree_util import tree_flatten, tree_unflatten
from jaxtyping import PyTree

# Loose alias: parameters may be raw arrays or pytrees of arrays.
Array = Any


def wrap_python_log_prob_fn(python_log_prob_fn: Callable[..., Array]):
    # Wrap a plain Python (non-JAX) log-probability function so it can be
    # called from JAX-traced code via host_callback, with a custom vmap rule
    # that evaluates each batch element on the host.
    @custom_vmap
    @wraps(python_log_prob_fn)
    def log_prob_fn(params: Array, data: PyTree) -> Array:
        dtype = _tree_dtype(params)
        inputs = {"params": params, "data": data}
        return host_callback.call(
            python_log_prob_fn,
            inputs,
            result_shape=jax.ShapeDtypeStruct((), dtype),
        )

    @log_prob_fn.def_vmap
    def _(
        axis_size: int, in_batched: List[bool], params: Array, data: PyTree
    ) -> Tuple[Array, bool]:
        del axis_size, in_batched

        if _arraylike(params):
            flat_params = params
            eval_one = python_log_prob_fn
        else:
            # Flatten each ensemble member so the host loop can iterate over rows.
            flat_params, unravel = ravel_ensemble(params)

            def eval_one(x):
                return python_log_prob_fn(unravel(x))

        result_shape = jax.ShapeDtypeStruct((flat_params.shape[0],), flat_params.dtype)

        result = host_callback.call(
            lambda y: np.stack([eval_one({"params": x, "data": data}) for x in y]),
            flat_params,
            result_shape=result_shape,
        )
        return (
            result,
            True,
        )

    return log_prob_fn


def _tree_dtype(tree: PyTree) -> Any:
    # Promote the dtypes of all leaves in the pytree to a common result dtype.
    leaves, _ = tree_flatten(tree)
    from_dtypes = [dtypes.dtype(leaf) for leaf in leaves]
    return dtypes.result_type(*from_dtypes)


def _arraylike(x: Array) -> bool:
    return (
        isinstance(x, np.ndarray)
        or isinstance(x, jnp.ndarray)
        or hasattr(x, "__jax_array__")
        or np.isscalar(x)
    )


UnravelFn = Callable[[Array], PyTree]

# safe_zip raises if the iterables have different lengths.
zip = safe_zip


def ravel_ensemble(coords: PyTree) -> Tuple[Array, UnravelFn]:
    # Flatten each ensemble member (leading axis) of a pytree into one row and
    # return the stacked rows plus a function that unravels a single row.
    leaves, treedef = tree_flatten(coords)
    flat, unravel_inner = _ravel_inner(leaves)

    def unravel_one(flat):
        return tree_unflatten(treedef, unravel_inner(flat))

    return flat, unravel_one


def _ravel_inner(lst: List[Array]) -> Tuple[Array, UnravelFn]:
    # Ravel a list of leaf arrays along their trailing dimensions, promoting to
    # a common dtype when the leaves disagree.
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(leaf) for leaf in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    shapes = [jnp.shape(x)[1:] for x in lst]
    indices = np.cumsum([int(np.prod(s)) for s in shapes])

    if all(dt == to_dtype for dt in from_dtypes):
        # All leaves already share a dtype, so no casting is needed.
        del from_dtypes, to_dtype

        def unravel(arr: Array) -> PyTree:
            chunks = jnp.split(arr, indices[:-1])
            return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]

        def ravel(arg):
            return jnp.concatenate([jnp.ravel(e) for e in arg])

        raveled = jax.vmap(ravel)(lst)
        return raveled, unravel

    else:

        def unravel(arr: Array) -> PyTree:
            arr_dtype = dtypes.dtype(arr)
            if arr_dtype != to_dtype:
                raise TypeError(
                    f"unravel function given array of dtype {arr_dtype}, "
                    f"but expected dtype {to_dtype}"
                )
            chunks = jnp.split(arr, indices[:-1])
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # silence precision-loss warnings on cast
                return [
                    lax.convert_element_type(chunk.reshape(shape), dtype)
                    for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
                ]

        def ravel(arg):
            return jnp.concatenate(
                [jnp.ravel(lax.convert_element_type(e, to_dtype)) for e in arg]
            )

        raveled = jax.vmap(ravel)(lst)
        return raveled, unravel