globus-labs / cascade / build 18732618520

22 Oct 2025 11:21PM UTC. Coverage: 25.34% (-70.0%) from 95.34%.
Pull Request #70: Academy proto (github)
Commit by miketynes: "fix init, logging init, waiting"

261 of 1030 relevant lines covered (25.34%), 0.25 hits per line.

Source File: /cascade/learning/mace.py (file coverage: 16.33%)
"""Interface to the higher-order equivariant neural networks
of `Batatia et al. <https://arxiv.org/abs/2206.07697>`_"""

import logging
import random
from pathlib import Path
from tempfile import TemporaryDirectory

import ase
import torch
import numpy as np
import pandas as pd
from ase import Atoms, data
from ase.calculators.calculator import Calculator
from ignite.engine import Engine, Events
from mace.data import AtomicData
from mace.data.utils import config_from_atoms
from mace.modules import WeightedHuberEnergyForcesStressLoss, ScaleShiftMACE
from mace.tools import AtomicNumberTable
from mace.tools.torch_geometric.dataloader import DataLoader
from mace.tools.scripts_utils import extract_config_mace_model
from mace.calculators import MACECalculator

from cascade.learning.base import BaseLearnableForcefield, State
from cascade.learning.finetuning import MultiHeadConfig
from cascade.learning.utils import estimate_atomic_energies

logger = logging.getLogger(__name__)

MACEState = ScaleShiftMACE
"""Just the model, which we require to be the MACE variant that includes the scale-shift logic"""

def _update_offset_factors(model: ScaleShiftMACE, train_data: list[Atoms], train_loader: DataLoader, device: str):
    """Update the atomic energies and scale offset layers of a model

    Args:
        model: Model to be adjusted
        train_data: Training dataset
        train_loader: Loader built using the training set
        device: Device on which to perform inference
    """
    # Update the atomic energies using the data from all trajectories
    z_table = AtomicNumberTable(model.atomic_numbers.cpu().numpy().tolist())
    new_ae = model.atomic_energies_fn.atomic_energies.cpu().numpy()
    atomic_energies_dict = estimate_atomic_energies(train_data)
    for s, e in atomic_energies_dict.items():
        new_ae[z_table.zs.index(data.atomic_numbers[s])] = e
    with torch.no_grad():
        old_ae = model.atomic_energies_fn.atomic_energies
        model.atomic_energies_fn.atomic_energies = torch.from_numpy(new_ae).to(old_ae.dtype).to(old_ae.device)

    # Update the shift of the energy scale
    errors = []
    for batch in train_loader:
        batch = batch.to(device)
        num_atoms = batch.ptr[1:] - batch.ptr[:-1]  # Use the offsets to compute the number of atoms per inference
        ml = model(
            batch,
            training=False,
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
        )
        error = (ml["energy"] - batch["energy"]) / num_atoms
        errors.extend(error.cpu().detach().numpy().tolist())
    model.scale_shift.shift -= np.mean(errors)

# TODO (wardlt): Use https://github.com/ACEsuit/mace/pull/830 when merged
def freeze_layers(model: torch.nn.Module, n: int = 4) -> None:
    """Freeze the first `n` layers of a model. If `n` is negative, freeze the last `|n|` layers.

    Args:
        model (torch.nn.Module): The model.
        n (int): The number of layers to freeze.
    """
    layers = list(model.children())
    num_layers = len(layers)

    logger.info(f"Total layers in model: {num_layers}")

    if abs(n) > num_layers:
        logger.warning(
            f"Requested {n} layers, but model only has {num_layers}. Adjusting `n` to fit the model."
        )
        n = num_layers if n > 0 else -num_layers

    frozen_layers = layers[:n] if n > 0 else layers[n:]

    logger.info(f"Freezing {len(frozen_layers)} layers.")

    for layer in frozen_layers:
        for param in layer.parameters():
            param.requires_grad = False
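# Hypothetical usage note (not part of this module): `n` counts entries of
# `model.children()`, so which MACE blocks end up frozen depends on the order in which
# the model registers its submodules. A sketch of freezing the first two blocks of a
# loaded ScaleShiftMACE before fine-tuning, then collecting what remains trainable:
#
#     freeze_layers(model, n=2)
#     trainable = [p for p in model.parameters() if p.requires_grad]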

def atoms_to_loader(atoms: list[Atoms], batch_size: int, z_table: AtomicNumberTable, r_max: float, **kwargs):
    """Make a data loader from a list of ASE atoms objects

    Args:
        atoms: Atoms from which to create the loader
        batch_size: Batch size for the loader
        z_table: Map between the atom indices used by MACE and atomic numbers in the periodic table
        r_max: Cutoff distance
    """

    def _prepare_atoms(my_atoms: Atoms):
        """MACE expects the training outputs to be stored in `info` and `arrays`"""
        # Start with a copy of positions, which should always be available
        my_atoms.arrays.update({
            'positions': my_atoms.positions,
        })

        if my_atoms.calc is None:
            return my_atoms  # No calc, no results

        # Now make an info dictionary if one doesn't exist yet
        if my_atoms.info is None:
            my_atoms.info = {}

        # Copy over all property data which exists
        if 'energy' in my_atoms.calc.results:
            my_atoms.info['energy'] = my_atoms.get_potential_energy()

        if 'stress' in my_atoms.calc.results:
            my_atoms.info['stress'] = my_atoms.get_stress()

        if 'forces' in my_atoms.calc.results:
            my_atoms.arrays['forces'] = my_atoms.get_forces()

        return my_atoms

    atoms = [config_from_atoms(_prepare_atoms(a)) for a in atoms]
    return DataLoader(
        [AtomicData.from_config(c, z_table=z_table, cutoff=r_max) for c in atoms],
        batch_size=batch_size,
        **kwargs
    )
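# Hypothetical usage sketch (the names below are illustrative, not part of this module):
# build a small evaluation loader from ASE structures created elsewhere.
#
#     from ase.build import bulk
#     structures = [bulk('Cu', 'fcc', a=3.6, cubic=True) for _ in range(4)]
#     loader = atoms_to_loader(structures, batch_size=2,
#                              z_table=AtomicNumberTable([29]), r_max=5.0,
#                              shuffle=False, drop_last=False)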

class MACEInterface(BaseLearnableForcefield[MACEState]):
    """Interface to the `MACE library <https://github.com/ACEsuit/mace>`_"""

    def create_extra_heads(self, model: ScaleShiftMACE, num_heads: int) -> list[ScaleShiftMACE]:
        """Create multiple instances of a ScaleShiftMACE model that share some of the same layers

        The new models will share the node embedding, interaction, and product layers,
        but will have separate atomic energy, readout, and scale_shift layers.

        Args:
            model: Model to be replicated
            num_heads: Number of replicas to create
        Returns:
            Additional copies of the model with the same internal layers
        """

        _shared_layers = ('node_embedding', 'interactions', 'products')

        output = []
        for _ in range(num_heads):
            # Make a deep copy of the model
            new_model = self.get_model(self.serialize_model(model))

            # Copy over the shared layers
            for layer in _shared_layers:
                setattr(new_model, layer, getattr(model, layer))
            output.append(new_model)
        return output
    def evaluate(self,
                 model_msg: bytes | State,
                 atoms: list[ase.Atoms],
                 batch_size: int = 64,
                 device: str = 'cpu') -> tuple[np.ndarray, list[np.ndarray], np.ndarray]:
        # Ready the model and the data
        model = self.get_model(model_msg)
        r_max = model.r_max.item()
        z_table = AtomicNumberTable(model.atomic_numbers.cpu().numpy().tolist())

        model.to(device)
        loader = atoms_to_loader(atoms, batch_size=batch_size, z_table=z_table, r_max=r_max, shuffle=False, drop_last=False)

        # Compile results
        energies = []
        forces = []
        stresses = []
        model.eval()
        for batch in loader:
            batch = batch.to(device)
            y = model(
                batch,
                training=False,
                compute_force=True,
                compute_virials=False,
                compute_stress=True,
            )
            energies.extend(y['energy'].cpu().detach().numpy())
            stresses.extend(y['stress'].cpu().detach().numpy())
            forces_numpy = y['forces'].cpu().detach().numpy()
            for i, j in zip(batch.ptr, batch.ptr[1:]):
                forces.append(forces_numpy[i:j, :])
        return np.array(energies), forces, np.array(stresses)
    def train(self,
              model_msg: bytes | State,
              train_data: list[Atoms],
              valid_data: list[Atoms],
              num_epochs: int,
              device: str = 'cpu',
              batch_size: int = 32,
              learning_rate: float = 1e-3,
              huber_deltas: tuple[float, float, float] = (0.5, 1, 1),
              force_weight: float = 10,
              stress_weight: float = 100,
              reset_weights: bool = False,
              patience: int | None = None,
              num_freeze: int | None = None,
              replay: MultiHeadConfig | None = None
              ) -> tuple[bytes, pd.DataFrame]:
        """Train a model

        Args:
            model_msg: Model to be retrained
            train_data: Structures used for training
            valid_data: Structures used for validation
            num_epochs: Number of training epochs
            device: Device (e.g., 'cuda', 'cpu') used for training
            batch_size: Batch size during training
            learning_rate: Initial learning rate for optimizer
            huber_deltas: Delta parameters for the loss functions for energy and force
            force_weight: Amount of weight to use for the force part of the loss function
            stress_weight: Amount of weight to use for the stress part of the loss function
            reset_weights: Whether to reset the weights before training
            patience: Halt training after the validation error increases for this many epochs
            num_freeze: Number of layers to freeze. Starts from the top of the model (node embedding).
                See: `Radova et al. <https://arxiv.org/html/2502.15582v1>`_
            replay: Settings for replaying an initial training set
        Returns:
            - model: Retrained model
            - history: Training history
        """

        # Load the model
        model = self.get_model(model_msg)
        r_max = model.r_max.item()
        z_table = AtomicNumberTable(model.atomic_numbers.cpu().numpy().tolist())

        # Reset weights if desired
        if reset_weights:
            config = extract_config_mace_model(model)
            model = ScaleShiftMACE(**config)
        model.to(device)

        # Unpin weights
        for p in model.parameters():
            p.requires_grad = True

        # Freeze desired layers
        if num_freeze is not None:
            freeze_layers(model, num_freeze)

        # Convert the training data from ASE -> MACE Configs
        train_loader = atoms_to_loader(train_data, batch_size, z_table, r_max, shuffle=True, drop_last=True)
        valid_loader = atoms_to_loader(valid_data, batch_size, z_table, r_max, shuffle=False, drop_last=True)

        # Update the atomic energies for the current dataset
        _update_offset_factors(model, train_data, train_loader, device)

        opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = WeightedHuberEnergyForcesStressLoss(
            energy_weight=1,
            forces_weight=force_weight,
            stress_weight=stress_weight,
            huber_delta=huber_deltas[0],
        )
        # Prepare the training engine
        train_losses = []

        def get_loss_stats(b, y):
            """Compute the losses"""
            na = b.ptr[1:] - b.ptr[:-1]
            return {
                'energy_mae': torch.mean(torch.abs(b['energy'] - y['energy']) / na).item(),
                'force_rmse': torch.sqrt(torch.square(b['forces'] - y['forces']).mean()).item(),
                'stress_rmse': torch.sqrt(torch.square(b['stress'] - y['stress']).mean()).item()
            }

        def train_step(engine, batch):
            """Borrowed from the training step used inside MACE"""
            model.train()
            opt.zero_grad()
            batch = batch.to(device)
            y = model(
                batch,
                training=True,
                compute_force=True,
                compute_virials=False,
                compute_stress=True,
            )
            loss = criterion(pred=y, ref=batch)
            loss.backward()
            opt.step()

            # Get the training stats
            detailed_loss = get_loss_stats(batch, y)
            detailed_loss['epoch'] = engine.state.epoch - 1
            detailed_loss['total_loss'] = loss.item()
            train_losses.append(detailed_loss)
            return loss.item()

        trainer = Engine(train_step)

        # Make the validation step
        valid_losses = []
        patience_status = {'best_loss': np.inf, 'patience': patience}
        @trainer.on(Events.EPOCH_COMPLETED)
        def validation_process(engine: Engine):
            model.eval()
            logger.info(f'Started validation for epoch {engine.state.epoch - 1}')

            for batch in valid_loader:
                batch.to(device)
                y = model(
                    batch,
                    training=False,
                    compute_force=True,
                    compute_virials=False,
                    compute_stress=True,
                )
                loss = criterion(pred=y, ref=batch)
                detailed_loss = get_loss_stats(batch, y)
                detailed_loss['epoch'] = engine.state.epoch - 1
                detailed_loss['total_loss'] = loss.item()
                valid_losses.append(detailed_loss)
            logger.info(f'Completed validation for epoch {engine.state.epoch - 1}')

            # Add early stopping if desired
            if patience_status['patience'] is not None:
                cur_loss = np.mean([x['total_loss'] for x in valid_losses if x['epoch'] == engine.state.epoch - 1])
                if cur_loss < patience_status['best_loss']:
                    patience_status['best_loss'] = cur_loss
                    patience_status['patience'] = patience
                else:
                    patience_status['patience'] -= 1

                if patience_status['patience'] < 0:
                    engine.terminate()
                    logger.info('Early stopping criterion met')
        # Add multi-head replay, if desired
        if replay is not None:
            # Downselect data, if a target size was provided
            #  TODO (wardlt): Use FPS to pick samples, as in https://arxiv.org/abs/2412.02877
            if replay.num_downselect:
                replay_data = replay.original_dataset.copy()
                random.shuffle(replay_data)
                replay_data = replay_data[:replay.num_downselect]
            else:
                replay_data = replay.original_dataset

            # Create and re-scale replay
            replay_model = self.create_extra_heads(model, 1)[0]
            replay_model.to(device)
            replay_loader = atoms_to_loader(replay_data, replay.batch_size or batch_size,
                                            z_table, r_max, shuffle=True, drop_last=True)
            _update_offset_factors(replay_model, replay_data, replay_loader, device)

            # Make the replay optimizer with a reduced learning rate
            replay_opt = torch.optim.Adam(replay_model.parameters(), lr=learning_rate / replay.lr_reduction)

            @trainer.on(Events.EPOCH_COMPLETED(every=replay.epoch_frequency))
            def replay_process(engine: Engine):
                replay_model.train()
                logger.info(f'Started replay for epoch {engine.state.epoch - 1}')

                for batch in replay_loader:
                    batch.to(device)
                    replay_opt.zero_grad()
                    y = replay_model(
                        batch,
                        training=True,
                        compute_force=True,
                        compute_virials=False,
                        compute_stress=True,
                    )
                    loss = criterion(pred=y, ref=batch)
                    loss.backward()
                    replay_opt.step()

                    detailed_loss = dict((f'{k}_replay', v) for k, v in get_loss_stats(batch, y).items())
                    detailed_loss['epoch'] = engine.state.epoch - 1
                    detailed_loss['total_loss_replay'] = loss.item()
                    valid_losses.append(detailed_loss)
        logger.info('Started training')
        trainer.run(train_loader, max_epochs=num_epochs)
        logger.info('Finished training')
        model.cpu()  # Force it off the GPU

        # Compile the loss
        train_losses = pd.DataFrame(train_losses).groupby('epoch').mean().reset_index()
        valid_losses = pd.DataFrame(valid_losses).groupby('epoch').mean().reset_index()
        log = train_losses.merge(valid_losses, on='epoch', suffixes=('_train', '_valid'))
        return self.serialize_model(model), log
    def make_calculator(self, model_msg: bytes | State, device: str) -> Calculator:
        # MACE calculator loads the model from disk, so let's write to disk
        with TemporaryDirectory(dir=self.scratch_dir, prefix='mace_') as tmp:
            model_path = Path(tmp) / 'model.pt'
            model_path.write_bytes(self.serialize_model(model_msg))

            return MACECalculator(model_paths=[model_path], device=device, compile_mode=None)
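For orientation, the sketch below shows one way the pieces above could fit together. The constructor arguments, file names, and dataset in this example are assumptions made for illustration; they are not part of the module listed here.

    from pathlib import Path

    from ase.io import read

    from cascade.learning.mace import MACEInterface

    interface = MACEInterface()  # Assumes the default constructor is sufficient
    model_msg = Path('starting_model.pt').read_bytes()  # A serialized ScaleShiftMACE, assumed to exist
    frames = read('training.traj', index=':')  # ASE structures with energies and forces attached

    # Fine-tune on the new frames, holding out the last few for validation,
    # then wrap the retrained model as an ASE calculator
    new_msg, history = interface.train(model_msg, frames[:-8], frames[-8:], num_epochs=32, device='cuda')
    calc = interface.make_calculator(new_msg, device='cuda')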