• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 13655281414

04 Mar 2025 01:55PM UTC coverage: 67.835%. Remained the same
13655281414

push

github

kazewong
format doc strings

1 of 1 new or added line in 1 file covered. (100.0%)

162 existing lines in 15 files now uncovered.

987 of 1455 relevant lines covered (67.84%)

1.36 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/flowMC/utils/EvolutionaryOptimizer.py
1
import jax
×
2
import jax.numpy as jnp
×
3
import tqdm
×
4
from evosax import CMA_ES
×
5
from jaxtyping import PRNGKeyArray
×
6

7
"""
×
8
WARNING: This file is in the process of being deprecated.
9
Do not use this optimizer
10

11
"""
12

13

14
class EvolutionaryOptimizer:
    """A wrapper class for the evosax CMA-ES strategy. Note that we do not aim to
    solve generic optimization problems, especially in high-dimensional spaces.

    The strategy samples candidates with ``clip_min=0`` and ``clip_max=1``
    (i.e. in the unit hypercube); each candidate is linearly rescaled to the
    user-supplied ``bound`` before the objective is evaluated.

    Parameters
    ----------
    ndims : int
        The dimension of the parameter space.
    popsize : int
        The population size of the evolutionary algorithm.
    verbose : bool
        Whether to print the progress bar.

    Attributes
    ----------
    strategy : evosax.CMA_ES
        The evolutionary strategy.
    es_params : evosax.CMA_ESParams
        The parameters of the evolutionary strategy.
    verbose : bool
        Whether to print the progress bar.
    history : list or jax.Array
        Rescaled populations recorded every ``keep_history_step`` generations
        during the most recent call to ``optimize`` (stacked into an array at
        the end of that call). An empty list if history was not requested.
    state
        The evosax strategy state after the most recent call to ``optimize``;
        ``None`` before the first call.
    bound : (ndims, 2) array or None
        The bound passed to the most recent call to ``optimize``.

    Methods
    -------
    optimize(objective, bound, n_loops = 100, seed = 9527, keep_history_step = 0)
        Optimize the objective function.
    get_result()
        Get the best member and the best fitness.
    """

    def __init__(self, ndims, popsize=100, verbose=False):
        self.strategy = CMA_ES(num_dims=ndims, popsize=popsize, elite_ratio=0.5)
        # The strategy works in the unit hypercube; samples are mapped onto the
        # user's bound inside ``optimize_step``.
        self.es_params = self.strategy.default_params.replace(clip_min=0, clip_max=1)
        self.verbose = verbose
        self.history = []
        self.state = None
        # Set by ``optimize``; initialized here so the attribute always exists.
        self.bound = None

    def optimize(self, objective, bound, n_loops=100, seed=9527, keep_history_step=0):
        """Optimize the objective function.

        Parameters
        ----------
        objective : Callable
            The objective function, which should be implemented in JAX. It is
            called with the whole rescaled population at once and must return
            one fitness value per member.
        bound : (ndims, 2) ndarray
            The bound of the parameter space: ``bound[:, 0]`` holds the lower
            and ``bound[:, 1]`` the upper limit of each dimension.
        n_loops : int
            The number of iterations (generations).
        seed : int
            The random seed.
        keep_history_step : int
            If positive, record the rescaled population every
            ``keep_history_step`` generations into ``self.history`` (stacked
            into a ``jax.Array`` at the end). If 0, no history is kept and
            ``self.history`` is left as an empty list.

        Returns
        -------
        None
        """
        rng = jax.random.PRNGKey(seed)
        key, subkey = jax.random.split(rng)
        progress_bar = (
            tqdm.tqdm(range(n_loops), "Generation: ")
            if self.verbose
            else range(n_loops)
        )
        self.bound = bound
        self.state = self.strategy.initialize(key, self.es_params)

        record_history = keep_history_step > 0
        # Always reset so ``history`` never reports data from a previous run.
        self.history = []
        for i in progress_bar:
            subkey, self.state, theta = self.optimize_step(
                subkey, self.state, objective, bound
            )
            # ``record_history`` is checked first to avoid a modulo by zero.
            if record_history and i % keep_history_step == 0:
                self.history.append(theta)
            if self.verbose:
                progress_bar.set_description(
                    f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}"
                )
        if record_history:
            self.history = jnp.array(self.history)

    def optimize_step(self, key: "PRNGKeyArray", state, objective: callable, bound):
        """Run one ask / evaluate / tell generation of the strategy.

        Parameters
        ----------
        key : PRNGKeyArray
            PRNG key; it is split and the returned key should be used for the
            next step.
        state
            The current evosax strategy state.
        objective : Callable
            Objective evaluated on the rescaled population.
        bound : (ndims, 2) ndarray
            Lower (``bound[:, 0]``) and upper (``bound[:, 1]``) limits.

        Returns
        -------
        key : PRNGKeyArray
            The advanced key for the next step.
        state
            The updated strategy state.
        theta : ndarray
            The population rescaled to ``bound`` (the points the objective was
            evaluated on).
        """
        key, subkey = jax.random.split(key)
        x, state = self.strategy.ask(subkey, state, self.es_params)
        lower, upper = bound[:, 0], bound[:, 1]
        theta = x * (upper - lower) + lower
        fitness = objective(theta)
        # The strategy is told about the unit-cube samples ``x`` (not ``theta``),
        # since that is the space it operates in.
        state = self.strategy.tell(
            x, fitness.astype(jnp.float32), state, self.es_params
        )
        return key, state, theta

    def get_result(self):
        """Get the best member and the best fitness.

        Returns
        -------
        best_member : (ndims,) ndarray
            The best member, rescaled to the bound used in the last
            ``optimize`` call.
        best_fitness : float
            The best fitness as tracked by the strategy state.

        Raises
        ------
        RuntimeError
            If ``optimize`` has not been called yet.
        """
        if self.state is None:
            raise RuntimeError(
                "No optimization has been run yet; call optimize() first."
            )
        lower, upper = self.bound[:, 0], self.bound[:, 1]
        best_member = self.state.best_member * (upper - lower) + lower
        best_fitness = self.state.best_fitness
        return best_member, best_fitness
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc