globus-labs / cascade — build 18732618520 (github), Pull Request #70: Academy proto
Committer miketynes: "fix init, logging init, waiting"
22 Oct 2025 11:21PM UTC — coverage 25.34% (261 of 1030 relevant lines covered, 0.25 hits per line), down 70.0 points from 95.34%

Source file: /cascade/proxima/__init__.py — 27.23% of relevant lines covered. Only lines executed at
import time (module imports, class and method definitions, and class-level attributes) are covered;
every statement inside the method bodies is uncovered.
"""ASE-compatible implementations of the online learning strategy of
`Zamora et al. <https://dl.acm.org/doi/abs/10.1145/3447818.3460370>`_, Proxima."""
from collections import deque
from typing import List, Optional, Any
from pathlib import Path
from random import random
import logging
import json

import numpy as np
import pandas as pd
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.db import connect

from cascade.learning.base import BaseLearnableForcefield
from cascade.calculator import EnsembleCalculator
from cascade.utils import to_voigt

logger = logging.getLogger(__name__)


class SerialLearningCalculator(Calculator):
    """A Calculator which switches between a physics-based calculator and one
    being trained to emulate it.

    Determines when to switch between the physics and learnable calculator based
    on an uncertainty metric from the learnable calculator.

    Will run ``history_length`` steps under physics before considering the surrogate.

    Switching can be applied smoothly by mixing physics-
    and surrogate-derived energies, forces, and stresses.
    The calculator will move between full surrogate
    and full physics at a rate controlled by the ``n_blending_steps`` parameter.

    Parameters for the calculator are:

    target_calc: BaseCalculator
        A physics-based calculator used to provide training data for the learnable surrogate
    learner: BaseLearnableForcefield
        A class used to train the forcefield and generate an updated calculator
    models: list of State associated with learner
        One or more sets of objects that define the architecture and weights of the
        surrogate model. These weights are used by the learner.
    device: str
        Device used for running the learned surrogates
    train_kwargs: dict
        Dictionary of parameters passed to the train function of the learner
    train_freq: int
        After how many new data points to retrain the surrogate model
    train_max_size: int, optional
        Maximum size of training set to use when updating model. Set to ``None`` to use all data
    train_recency_bias: float
        Bias towards selecting newer points if using only a subset of the available training data.
        Weights will be assigned to each point in a geometric series such that the most recent
        point is ``train_recency_bias`` times more likely to be selected than the least recent.
    target_ferr: float
        Target maximum difference between the forces predicted by the target
        calculator and the learnable surrogate
    train_from_original: bool
        Whether to use the original models provided when creating the class as the starting
        point for training rather than the models produced from the most-recent training.
        The calculator will preserve the original models only if ``True``.
    history_length: int
        The number of previous observations of the error between target and surrogate
        function to use when establishing a link between the uncertainty metric
        and the maximum observed error. Will run exactly this number of target
        calculations before considering using the surrogate
    min_target_fraction: float
        Minimum fraction of timesteps to run the target function.
        This value is used as the probability of running the target function
        even if it need not be used based on the UQ metric.
    n_blending_steps: int
        How many timesteps to smoothly combine target and surrogate forces.
        When the threshold is satisfied the mixture shifts step-by-step from
        target towards ML forces.
    db_path: Path or str
        Database in which to store the results of running the target calculator,
        which are used to train the surrogate model
    log_path: Path or str or None
        Optional. Path to a directory in which to write model files and training logs.
        Writes logs after each re-training event.
    """
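
    # A minimal usage sketch (illustrative only; the target calculator, learner,
    # and model states named below are placeholders, not objects defined in this
    # module):
    #
    #   calc = SerialLearningCalculator(
    #       target_calc=my_dft_calc,            # any ASE Calculator
    #       learner=my_learner,                 # a BaseLearnableForcefield
    #       models=[model_state_a, model_state_b],
    #       train_freq=4,
    #       db_path='my-proxima.db',
    #   )
    #   atoms.calc = calc                       # then drive dynamics as usual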

    default_parameters = {
        'target_calc': None,
        'learner': None,
        'models': None,
        'device': 'cpu',
        'train_kwargs': {'num_epochs': 8},
        'train_freq': 1,
        'train_max_size': None,
        'train_recency_bias': 1.,
        'train_from_original': False,
        'target_ferr': 0.1,  # TODO (wardlt): Make the error metric configurable
        'min_target_fraction': 0.,
        'n_blending_steps': 0,
        'history_length': 8,
        'db_path': 'proxima.db',
        'log_path': None
    }

    train_logs: Optional[List[pd.DataFrame]] = None
    """Logs from the most recent model training"""
    surrogate_calc: Optional[EnsembleCalculator] = None
    """Cache for the surrogate calculator"""
    error_history: Optional[deque[tuple[float, float]]] = None
    """History of pairs of the uncertainty metric and observed error"""
    alpha: Optional[float] = None
    """Coefficient which relates distrust metric and observed error"""
    threshold: Optional[float] = None
    """Current threshold for the uncertainty metric beyond which the target calculator will be used"""
    used_surrogate: Optional[bool] = None
    """Whether the last invocation used the surrogate model"""
    new_points: int = 0
    """How many new points have been acquired since the last model update"""
    total_invocations: int = 0
    """Total number of calls to the calculator"""
    target_invocations: int = 0
    """Total number of calls to the target calculator"""
    blending_step: np.int64 = np.int_(0)
    """Ranges from 0 to n_blending_steps, corresponding to
    full physics and full surrogate, respectively"""
    lambda_target: float = 1.
    """Weight on the target (physics) results, ranging from 1 (full physics) to 0 (full surrogate)"""
    model_version: int = 0
    """How many times the model has been retrained"""
    models: Optional[list] = None
    """Ensemble of models from the latest training invocation. The same object as ``parameters['models']`` unless ``train_from_original`` is set"""

    def set(self, **kwargs):
        # TODO (wardlt): Fix ASE such that it does not try to do a numpy comparison on everything
        self.parameters.update(kwargs)
        self.reset()

    @property
    def implemented_properties(self) -> List[str]:
        return self.parameters['target_calc'].implemented_properties

    @property
    def learner(self) -> BaseLearnableForcefield:
        return self.parameters['learner']

    @staticmethod
    def smoothing_function(x):
        """Smoothing used for blending surrogate with physics"""
        return 0.5 * ((np.cos(np.pi * x)) + 1)
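
    # Illustrative note on smoothing_function above: it maps the blend coordinate
    # x = blending_step / n_blending_steps to the physics weight lambda_target,
    # so x=0 gives 1.0 (pure target physics), x=0.5 gives 0.5, and x=1 gives 0.0
    # (pure surrogate).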

    def retrain_surrogate(self):
        """Retrain the surrogate models using the currently-available data"""
        # Start with the models set as the originals
        self.models = self.parameters['models'].copy()

        # Load in the data from the db
        db_path = self.parameters['db_path']
        if not Path(db_path).is_file():
            logger.debug(f'No data at {db_path} yet')
            return

        # Retrieve the data such that the oldest training entry is first
        with connect(db_path) as db:
            all_atoms = [a for a in db.select('', sort='-age')]
            assert len(all_atoms) < 2 or all_atoms[0].ctime < all_atoms[1].ctime, 'Logan got the sort order backwards'
            all_atoms = [a.toatoms() for a in all_atoms]
        logger.info(f'Loaded {len(all_atoms)} structures from {db_path} for retraining {len(self.parameters["models"])} models')
        if len(all_atoms) < 10:
            logger.info('Too few entries to retrain. Skipping')
            return

        # Determine where the updated models will be stored
        model_list = self.models = (
            [None] * len(self.parameters['models'])  # Create a new list
            if self.parameters['train_from_original'] else
            self.parameters['models']  # Edit it in place
        )

        # Train each model using a different, randomly-selected subset of the data
        self.train_logs = []
        for i, model_msg in enumerate(self.parameters['models']):
            # Assign splits such that the same entries do not switch between train/validation as the dataset grows
            rng = np.random.RandomState(i)
            is_train = rng.uniform(0, 1, size=(len(all_atoms),)) > 0.1  # TODO (wardlt): Make this configurable
            train_atoms = [all_atoms[i] for i in np.where(is_train)[0]]  # Where preserves sort
            valid_atoms = [all_atoms[i] for i in np.where(np.logical_not(is_train))[0]]

            # Downselect training set if it is larger than the fixed maximum
            train_max_size = self.parameters['train_max_size']
            if train_max_size is not None and len(train_atoms) > train_max_size:
                # Decrease the validation size proportionally
                valid_size = train_max_size * len(valid_atoms) // len(train_atoms)

                train_weights = np.geomspace(1, self.parameters['train_recency_bias'], len(train_atoms))
                train_ids = rng.choice(len(train_atoms), size=(train_max_size,), p=train_weights / train_weights.sum(), replace=False)
                train_atoms = [train_atoms[i] for i in train_ids]

                if valid_size > 0:
                    valid_weights = np.geomspace(1, self.parameters['train_recency_bias'], len(valid_atoms))
                    valid_ids = rng.choice(len(valid_atoms), size=(valid_size,), p=valid_weights / valid_weights.sum(), replace=False)
                    valid_atoms = [valid_atoms[i] for i in valid_ids]

            logger.debug(f'Training model {i} on {len(train_atoms)} structures and validating on {len(valid_atoms)}')
            new_model_msg, log = self.learner.train(model_msg, train_atoms, valid_atoms, **self.parameters['train_kwargs'])
            model_list[i] = new_model_msg
            self.train_logs.append(log)
            logger.debug(f'Finished training model {i}')
        self.model_version += 1
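
    # Illustrative note on the recency bias used in retrain_surrogate above:
    # with train_recency_bias=4 and five candidate structures (oldest first),
    # np.geomspace(1, 4, 5) yields sampling weights of roughly [1.0, 1.41, 2.0, 2.83, 4.0],
    # so the newest structure is four times as likely to be drawn as the oldest.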

    def write_log_to_dir(self, log_dir: Path | None = None):
        """Write the current proxima state to a logging directory

        Args:
            log_dir: Path to the output directory. Uses the ``log_path`` parameter
                of the calculator if not provided
        """

        # Output dir
        if log_dir is not None:
            out_dir = log_dir
        elif self.parameters.get('log_path') is not None:
            out_dir = self.parameters['log_path']
        else:
            raise ValueError('Logging requires either setting the `log_path` parameter of the calculator, '
                             'or supplying one to this function.')
        out_dir = Path(out_dir)

        state = self.get_state()

        # Make the output directory
        out_dir.mkdir(exist_ok=True, parents=True)
        for i, model in enumerate(state.pop('models', [])):
            out_dir.joinpath(f'model_{i}.bin').write_bytes(model)

        # Save the training log
        for i, log in enumerate(state.pop('train_logs') or []):
            log.to_csv(out_dir.joinpath(f'train-log_{i}.csv'), index=False)

        with out_dir.joinpath('proxima.json').open('w') as fp:
            json.dump(state, fp, indent=2)
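
    # Illustrative note: after write_log_to_dir runs, the output directory holds
    # model_<i>.bin for each ensemble member, train-log_<i>.csv for each training
    # log (when available), and proxima.json with the remaining control state.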

    def calculate(
            self, atoms=None, properties=all_properties, system_changes=all_changes
    ):
        super().calculate(atoms, properties, system_changes)

        # Start by running an ensemble of surrogate models
        if self.surrogate_calc is None:
            self.retrain_surrogate()
            self.surrogate_calc = EnsembleCalculator(
                calculators=[self.learner.make_calculator(m, self.parameters['device']) for m in self.models]
            )
            if self.parameters.get('log_path') is not None:  # Log if desired
                self.write_log_to_dir()
        self.surrogate_calc.calculate(atoms, properties + ['forces'], system_changes)  # Make sure forces are computed too

        # Compute an uncertainty metric for the ensemble model
        #  We use, for now, the maximum over all atoms of the mean difference in force prediction.
        forces_ens = self.surrogate_calc.results['forces_ens']
        forces_diff = np.linalg.norm(forces_ens - self.surrogate_calc.results['forces'][None, :, :], axis=-1).mean(axis=0)  # Mean diff per atom
        unc_metric = forces_diff.max()
        logger.debug(f'Computed the uncertainty metric for the model to be: {unc_metric:.2e}')

        # Check whether to use the result from the surrogate
        uq_small_enough = self.threshold is not None and unc_metric < self.threshold
        self.used_surrogate = uq_small_enough and (random() > self.parameters['min_target_fraction'])
        self.total_invocations += 1
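
        # Illustrative note: with min_target_fraction=0.25, roughly one in four steps
        # still invokes the target calculator even when the uncertainty threshold is
        # satisfied, which keeps fresh training data flowing into the database.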

        # Track blending parameters for surrogate/target
        increment = +1 if self.used_surrogate else -1
        self.blending_step = np.clip(self.blending_step + increment, 0, self.parameters['n_blending_steps'])
        self.lambda_target = self.smoothing_function(self.blending_step / self.parameters['n_blending_steps'])

        # Case: fully use the surrogate
        if self.used_surrogate and self.blending_step == self.parameters['n_blending_steps']:
            logger.debug(f'The uncertainty metric is low enough ({unc_metric:.2e} < {self.threshold:.2e}). Using the surrogate result.')
            self.results = self.surrogate_calc.results.copy()
            return

        # If not, run the target calculator and use that result
        target_calc: Calculator = self.parameters['target_calc']
        target_calc.calculate(atoms, properties, system_changes)
        self.target_invocations += 1

        if self.blending_step > 0:
            # Return a blend if appropriate
            results_target = target_calc.results
            results_surrogate = self.surrogate_calc.results
            self.results = {}
            for k in results_surrogate.keys():
                if k in results_target.keys():  # Blend on the intersection of keys
                    r_target, r_surrogate = results_target[k], results_surrogate[k]
                    #  Handle differences between the Voigt and (3, 3) stress conventions
                    if k == 'stress' and r_target.shape != r_surrogate.shape:
                        r_target, r_surrogate = map(to_voigt, [r_target, r_surrogate])
                    self.results[k] = self.lambda_target * r_target + (1 - self.lambda_target) * r_surrogate
                else:
                    # The surrogate may have some extra results, which we store as-is
                    self.results[k] = results_surrogate[k]
        else:
            # Otherwise just return the target
            self.results = target_calc.results.copy()

        # Add this new result to the training set
        db_atoms = atoms.copy()
        db_atoms.calc = target_calc
        with connect(self.parameters['db_path']) as db:
            db.write(db_atoms)

        # Reset the surrogate if the training frequency has been reached
        surrogate_forces = self.surrogate_calc.results['forces']
        self.new_points = (self.new_points + 1) % self.parameters['train_freq']
        if self.new_points == 0:
            self.surrogate_calc = None

        # Update the alpha parameter, which relates uncertainty and observed error
        #  See Section 3.2 from https://dl.acm.org/doi/abs/10.1145/3447818.3460370
        #  Main difference: We do not fit an intercept when estimating \alpha
        actual_err = np.linalg.norm(target_calc.results['forces'] - surrogate_forces, axis=-1).max()
        if self.error_history is None:
            self.error_history = deque(maxlen=self.parameters['history_length'])
        self.error_history.append((float(unc_metric), float(actual_err)))

        if len(self.error_history) < self.parameters['history_length']:
            logger.debug(f'Too few entries in training history. {len(self.error_history)} < {self.parameters["history_length"]}')
            return
        uncert_metrics, obs_errors = zip(*self.error_history)

        # Special case: uncertainty metrics are all zero. Happens when using the same pre-trained weights for the whole ensemble.
        all_zero = np.allclose(uncert_metrics, 0.)
        if all_zero:
            logger.debug('All uncertainty metrics are zero. Setting threshold to zero')
            self.threshold = 0.
            return

        many_alphas = np.true_divide(obs_errors, np.clip(uncert_metrics, 1e-6, a_max=np.inf))  # Alpha's units: error / UQ
        self.alpha = np.mean(many_alphas)
        assert self.alpha >= 0

        # Update the threshold used to determine if the surrogate is usable
        if self.threshold is None:
            # Use the initial estimate for alpha to set a conservative threshold
            #  Following Eq. 1 of https://dl.acm.org/doi/abs/10.1145/3447818.3460370,
            self.threshold = self.parameters['target_ferr'] / self.alpha  # Units: error / (error / UQ) -> UQ
            self.threshold /= 2  # Make the threshold even stricter than we estimate. TODO (wardlt): Make this adjustable
        else:
            # Update according to Eq. 3 of https://dl.acm.org/doi/abs/10.1145/3447818.3460370
            current_err = np.mean([e for _, e in self.error_history])
            self.threshold -= (current_err - self.parameters['target_ferr']) / self.alpha
            self.threshold = max(self.threshold, 0)  # Keep it at least zero (assuming UQ signals are nonnegative)
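
        # Illustrative note on the threshold controller above (numbers invented for
        # the example): if the error history gives alpha ~= 2.0 (observed error about
        # twice the UQ metric) and target_ferr = 0.1, the first threshold is set to
        # 0.1 / 2.0 / 2 = 0.025. If the running mean error later drifts to 0.12, the
        # threshold tightens by (0.12 - 0.1) / 2.0 = 0.01 on the next target step.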

    def get_state(self) -> dict[str, Any]:
        """Get the state of the learner in a form that can be saved to disk using pickle

        The state contains the current threshold control parameters, error history, retraining status, and the latest models.

        Returns:
            Dictionary containing the state of the model(s)
        """

        output = {
            'threshold': None if self.threshold is None else float(self.threshold),
            'alpha': None if self.alpha is None else float(self.alpha),
            'blending_step': int(self.blending_step),
            'error_history': list(self.error_history) if self.error_history is not None else [],
            'new_points': self.new_points,
            'train_logs': self.train_logs,
            'total_invocations': self.total_invocations,
            'target_invocations': self.target_invocations,
            'model_version': self.model_version
        }

        # Write models if they have been assembled into a calculator
        if self.surrogate_calc is not None:
            output['models'] = [self.learner.serialize_model(s) for s in self.models]
        return output

    def set_state(self, state: dict[str, Any]):
        """Set the state of the learner using the state saved by :meth:`get_state`

        Args:
            state: State containing the threshold control system parameters and trained models, if available
        """

        # Set the state of the threshold
        self.alpha = state['alpha']
        self.blending_step = state['blending_step']
        self.threshold = state['threshold']
        self.new_points = state['new_points']
        self.error_history = deque(maxlen=self.parameters['history_length'])
        self.error_history.extend(state['error_history'])
        self.train_logs = state['train_logs']
        self.total_invocations = state['total_invocations']
        self.target_invocations = state['target_invocations']
        self.model_version = state['model_version']

        # Remake the surrogate calculator, if available
        if 'models' in state:
            # Store in a different place depending on whether we are training from original or latest
            if self.parameters['train_from_original']:
                self.models = state['models']
            else:
                self.models = self.parameters['models'] = state['models']  # Both are the same
            self.surrogate_calc = EnsembleCalculator(
                calculators=[self.learner.make_calculator(m, self.parameters['device']) for m in state['models']]
            )
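
    # Illustrative checkpoint/resume sketch (the file name is an assumption):
    #
    #   import pickle
    #   with open('proxima-state.pkl', 'wb') as fp:
    #       pickle.dump(calc.get_state(), fp)
    #   ...
    #   with open('proxima-state.pkl', 'rb') as fp:
    #       calc.set_state(pickle.load(fp))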

    def todict(self, skip_default=True):
        # Never skip defaults because testing for equality between current and default breaks for our data types
        output = super().todict(False)

        # Drop a brief status report into here as well
        #  Note: you must convert from numpy to Python int/floats before serialization
        output.update({
            'blending_step': int(self.blending_step),
            'total_invocations': int(self.total_invocations),
            'target_invocations': int(self.target_invocations),
        })
        if self.error_history is not None:
            output['current_error'] = float(np.mean([e for _, e in self.error_history]))

        # The models don't JSON-serialize, so skip them
        output.pop('models')
        output.pop('learner')
        return output