globus-labs / cascade · build 18732618520 · Pull Request #70: Academy proto
22 Oct 2025 11:21PM UTC · coverage: 25.34% (-70.0%) from 95.34%
github · miketynes: "fix init, logging init, waiting"
261 of 1030 relevant lines covered (25.34%) · 0.25 hits per line

Source File: /cascade/learning/torchani/__init__.py
"""Interface and glue code for models built using `TorchANI <https://github.com/aiqm/torchani>`_"""
from functools import partial
import copy

from ase import units, Atoms
import numpy as np
import pandas as pd
from ase.calculators.calculator import Calculator
import torch
from torch.utils.data import DataLoader
from torchani.nn import SpeciesEnergies, Sequential
from torchani import AEVComputer, ANIModel, EnergyShifter
from torchani.aev import SpeciesAEV
from torchani.data import collate_fn
from ignite.engine import Engine, Events

from cascade.learning.base import BaseLearnableForcefield, State
from cascade.learning.utils import estimate_atomic_energies
from cascade.learning.torchani.ase import Calculator as ANICalculator

__all__ = ['TorchANI', 'ANIModelContents']

ANIModelContents = tuple[AEVComputer, ANIModel, dict[str, float]]
"""Contents of the serialized form of a model:
1✔
25

26
1. Compute for atomic environments
27
2. The model which maps environments to energies
28
3. Ordered dict of chemical symbol to atomic energies (all Py3 dicts are ordered)
29
"""

# Padding values passed to ``collate_fn`` when batching entries with different numbers of atoms:
# padded species indices are marked with -1, and the numeric arrays are zero-padded
my_collate_dict = {
    'species': -1,
    'coordinates': 0.0,
    'forces': 0.0,
    'energies': 0.0,
    'cells': 0.0,
    'volumes': 0.0,
    'stresses': 0.0
}


def ase_to_ani(atoms: Atoms, species: list[str]) -> dict[str, torch.Tensor]:
    """Make an ANI-format dictionary from an ASE Atoms object

    Args:
        atoms: Atoms object to be converted
        species: List of species used to determine index given chemical symbol
    Returns:
        Atoms object converted to a dictionary of tensors
    """

    # An entry _must_ have the species and coordinates
    output = {
        'species': torch.from_numpy(np.array([species.index(s) for s in atoms.symbols])),
        'coordinates': torch.from_numpy(atoms.positions).float(),
        'cells': torch.from_numpy(atoms.cell.array).float(),
        'volumes': torch.from_numpy(np.array(atoms.get_volume())).float()
    }

    if atoms.calc is not None:
        if 'energy' in atoms.calc.results:
            output['energies'] = torch.from_numpy(np.atleast_1d(atoms.get_potential_energy())).float()
        if 'forces' in atoms.calc.results:
            output['forces'] = torch.from_numpy(atoms.get_forces()).float()
        if 'stress' in atoms.calc.results:
            output['stresses'] = torch.from_numpy(atoms.get_stress(voigt=False)).float()
    return output
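
# Illustrative usage sketch (not part of the original module); the structure, cell, and
# species ordering below are arbitrary examples:
#   water = Atoms('H2O', positions=[[0., 0., 0.96], [0., 0.76, -0.24], [0., 0., 0.]],
#                 cell=[10., 10., 10.], pbc=True)
#   entry = ase_to_ani(water, species=['H', 'O'])
#   entry['species']      # tensor([0, 0, 1])
#   entry['coordinates']  # float32 tensor with shape (3, 3)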


def make_data_loader(data: list[Atoms],
                     species: list[str],
                     batch_size: int,
                     train: bool,
                     **kwargs) -> DataLoader:
    """Make a data loader based on a list of Atoms

    Args:
        data: Data to use for the loader
        species: Map of chemical species to index in network
        batch_size: Batch size to use for the loader
        train: Whether this is a loader for training data. If so, sets ``shuffle`` and ``drop_last`` to True.
        kwargs: Passed to the ``DataLoader`` constructor
    """
    # Append training settings if set to train
    if train:
        kwargs['shuffle'] = True
        kwargs['drop_last'] = True

    return DataLoader([ase_to_ani(a, species) for a in data],
                      collate_fn=lambda x: collate_fn(x, my_collate_dict),
                      batch_size=max(min(batch_size, len(data)), 1),
                      **kwargs)
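
# Illustrative usage sketch (hypothetical variables):
#   loader = make_data_loader(train_frames, species=['H', 'O'], batch_size=32, train=True)
#   batch = next(iter(loader))
#   batch['coordinates'].shape  # (n_batch, max_atoms, 3), padded according to my_collate_dict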


def forward_batch(batch: dict[str, torch.Tensor],
                  aev_computer: AEVComputer,
                  nn: ANIModel,
                  atom_energies: np.ndarray,
                  pbc: torch.Tensor,
                  forces: bool = True,
                  stresses: bool = True,
                  train: bool = True,
                  device: str = 'cpu') -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Run the forward step on a batch of entries

    Args:
        batch: Batch from the data loader
        aev_computer: Atomic environment computer
        nn: Model which maps atomic environments to energies
        atom_energies: Array holding the reference energy for each species
        pbc: Periodic boundary conditions used by all members of the batch
        forces: Whether to compute forces
        stresses: Whether to compute stresses
        train: Whether we are in training mode
        device: Device on which to run computations
    Returns:
        - Energies for each member
        - Forces for each member, if ``forces``
        - Stresses for each member, if ``stresses``
    """

    # Move the data to the device
    batch_z = batch['species'].to(device)
    batch_x = batch['coordinates'].float().to(device).requires_grad_(forces)
    batch_c = batch['cells'].float().to(device)

    # Prepare for stress calculation
    scaling = None
    if stresses:
        scaling = torch.eye(3, requires_grad=True, dtype=batch_x.dtype, device=device)
        scaling = torch.tile(scaling[None, :, :], (batch_z.shape[0], 1, 1))
        batch_c = torch.matmul(batch_c, scaling)
        batch_x = torch.matmul(batch_x, scaling)

    # Compute the energy offset per member (run on the CPU because it's fast)
    species = batch['species'].numpy()
    batch_o = atom_energies[species].sum(axis=1, where=(species >= 0))
    batch_o = torch.from_numpy(batch_o).to(device)

    # Compute the AEVs individually because TorchANI assumes all entries have the same cell size
    batch_a = []
    for row_z, row_x, row_c in zip(batch_z, batch_x, batch_c):
        row_z = torch.unsqueeze(row_z, 0)
        row_x = torch.unsqueeze(row_x, 0)
        batch_a.append(aev_computer((row_z, row_x), row_c, pbc).aevs)
    batch_a = torch.concat(batch_a)
    batch_a = SpeciesAEV(batch_z, batch_a)

    # Get the energies for each member
    _, batch_e_pred = nn(batch_a)
    batch_e_pred = batch_e_pred + batch_o

    # Compute forces
    batch_f_pred = None
    if forces:
        batch_f_pred = -torch.autograd.grad(batch_e_pred.sum(), batch_x, create_graph=train or stresses)[0]

    # Compute stresses
    batch_s_pred = None
    if stresses:
        batch_v = batch['volumes'].float().to(device)
        batch_s_pred = torch.autograd.grad(batch_e_pred.sum(), scaling, create_graph=train)[0] / batch_v[:, None, None]

    return batch_e_pred, batch_f_pred, batch_s_pred
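
# In words: for each entry the predicted energy is E = NN(AEV) + sum of per-species reference
# energies over the real (non-padded) atoms; forces follow as -dE/dx, and stresses as the
# gradient with respect to the ``scaling`` (strain) matrix divided by the cell volume.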


def adjust_energy_scale(aev_computer: AEVComputer,
                        model: ANIModel,
                        loader: DataLoader,
                        atom_energies: np.ndarray,
                        device: str | torch.device = 'cpu') -> tuple[float, float]:
    """Adjust the last layer of an ANIModel such that its standard deviation matches that of the training data

    Args:
        aev_computer: Tool which computes atomic environments
        model: Model to be adjusted
        loader: Data loader
        atom_energies: Reference energy for each species
        device: Device on which to perform inference
    Returns:
        Scale and shift factors
    """

    # Iterate over the dataset to get the actual and observed standard deviation of atomic energies
    pbc = torch.from_numpy(np.ones((3,), bool)).to(device)  # TODO (don't hard code to 3D)
    true_energies = []
    pred_energies = []
    for batch in loader:
        # Get the actual energies and predicted energies for the system
        batch_e_pred, _, _ = forward_batch(batch, aev_computer, model, atom_energies, pbc, forces=False, stresses=False, train=False, device=device)
        batch_e = batch['energies'][:, 0].cpu().numpy()

        # Get the energy per atom w/o the reference energy
        species = batch['species'].numpy()
        batch_o = atom_energies[species].sum(axis=1, where=species >= 0)
        batch_n = (species >= 0).sum(axis=1, dtype=batch_e.dtype)
        batch_e_pred = batch_e_pred.detach().cpu().numpy()

        pred_energies.extend((batch_e_pred - batch_o) / batch_n)
        true_energies.extend((batch_e - batch_o) / batch_n)

    # Get the ratio in standard deviations and the sign
    true_std = np.std(true_energies)
    pred_std = np.std(pred_energies)
    factor = true_std / pred_std
    r = np.corrcoef(true_energies, pred_energies)[0, 1]
    if r < 0:
        factor *= -1

    # Update the last layer of each network to match the new scaling
    with torch.no_grad():
        for m in model.values():
            last_linear = m[-1]
            assert isinstance(last_linear, torch.nn.Linear), f'Last layer is not linear. It is {type(last_linear)}'
            assert last_linear.out_features == 1, f'Expected last layer to have one output. Found {last_linear.out_features}'
            last_linear.weight *= factor

    # Recompute the energies given the new scale factor
    pred_energies = []
    for batch in loader:
        batch_e_pred, _, _ = forward_batch(batch, aev_computer, model, atom_energies, pbc, forces=False, stresses=False, train=False, device=device)
        species = batch['species'].numpy()
        batch_o = atom_energies[species].sum(axis=1, where=species >= 0)
        batch_n = (species >= 0).sum(axis=1, dtype=batch_o.dtype)
        pred_energies.extend((batch_e_pred.detach().cpu().numpy() - batch_o) / batch_n)

    # Get the shift for the mean
    true_mean = np.mean(true_energies)
    pred_mean = np.mean(pred_energies)
    shift = pred_mean - true_mean

    # Update the last layer of each network to apply the new shift
    with torch.no_grad():
        for m in model.values():
            last_linear = m[-1]
            last_linear.bias -= shift

    return factor, shift
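
# Sketch of the adjustment, in the notation above: working with per-atom energies relative to
# the atomic reference energies,
#   factor = std(true) / std(pred)   (sign flipped if the two are anticorrelated)
#   shift  = mean(pred after rescaling) - mean(true)
# and the final Linear layer of each species network gets weight *= factor and bias -= shift.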


class TorchANI(BaseLearnableForcefield[ANIModelContents]):
    """Interface to the high-dimensional neural networks implemented by `TorchANI <https://github.com/aiqm/torchani>`_"""

    def evaluate(self,
                 model_msg: bytes | ANIModelContents,
                 atoms: list[Atoms],
                 batch_size: int = 64,
                 device: str = 'cpu') -> tuple[np.ndarray, list[np.ndarray], np.ndarray]:

        # TODO (wardlt): Put model in "eval" mode, skip making the graph when computing gradients in `forward_batch` <- performance optimizations

        # Unpack the model
        if isinstance(model_msg, bytes):
            model_msg = self.get_model(model_msg)
        aev_computer, model, atomic_energies = model_msg
        model.to(device)
        aev_computer.to(device)

        # Unpack the reference energies as a float32 array
        species = list(atomic_energies.keys())
        ref_energies = np.array([atomic_energies[s] for s in species]).astype(np.float32)

        # Build the data loader
        loader = make_data_loader(atoms, species, batch_size, train=False)

        # Run inference on all data
        energies = []
        forces = []
        stresses = []
        pbc = torch.from_numpy(np.ones((3,), bool)).to(device)  # TODO (don't hard code to 3D)
        for batch in loader:
            batch_e_pred, batch_f_pred, batch_s_pred = forward_batch(batch, aev_computer, model, ref_energies, pbc, stresses=True, train=False, device=device)
            energies.extend(batch_e_pred.detach().cpu().numpy())  # Energy and stress arrays have the same shape regardless of system size
            stresses.extend(batch_s_pred.detach().cpu().numpy())

            # The shape of the force array differs depending on size
            batch_n = (batch['species'] >= 0).sum(dim=1).cpu().numpy()  # Number of real atoms per batch
            for entry_f, entry_n in zip(batch_f_pred.detach().cpu().numpy(), batch_n):
                forces.append(entry_f[:entry_n, :])

        # Move model back from device
        model.to('cpu')
        aev_computer.to('cpu')

        return np.array(energies), list(forces), np.array(stresses)
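
    # Illustrative usage sketch (hypothetical variables; `learner` is a TorchANI instance
    # constructed elsewhere):
    #   energies, forces, stresses = learner.evaluate(model_msg, frames, batch_size=64, device='cuda')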

    def train(self,
              model_msg: bytes | State,
              train_data: list[Atoms],
              valid_data: list[Atoms],
              num_epochs: int,
              device: str = 'cpu',
              batch_size: int = 32,
              learning_rate: float = 1e-3,
              huber_deltas: tuple[float, float, float] = (0.5, 1, 1),
              force_weight: float = 10,
              stress_weight: float = 100,
              reset_weights: bool = False,
              scale_energies: bool = True,
              **kwargs) -> tuple[bytes, pd.DataFrame]:
        # Unpack the model and move to device
        if isinstance(model_msg, bytes):
            model_msg = self.get_model(model_msg)
        aev_computer, model, atomic_energies = model_msg
        species = list(atomic_energies.keys())
        model.to(device)
        aev_computer.to(device)

        # Re-fit the atomic energies
        atomic_energies.update(estimate_atomic_energies(train_data))
        ref_energies = np.array([atomic_energies[s] for s in species], dtype=np.float32)  # Don't forget to cast to f32!

        # Reset the weights, if desired
        def init_params(m):
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, a=1.0)
                torch.nn.init.zeros_(m.bias)

        if reset_weights:
            model.apply(init_params)

        # Build the data loaders
        pbc = torch.from_numpy(np.ones((3,), bool)).to(device)  # TODO (don't hard code to 3D)
        train_loader = make_data_loader(train_data, species, batch_size, train=True)
        valid_loader = make_data_loader(valid_data, species, batch_size, train=False)

        # Adjust output layers to match data distribution
        if scale_energies:
            adjust_energy_scale(aev_computer, model, train_loader, ref_energies, device)

        # Prepare optimizer and loss functions
        opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

        huber_e, huber_f, huber_s = huber_deltas
        loss_e = torch.nn.HuberLoss(reduction='none', delta=huber_e)
        loss_f = torch.nn.HuberLoss(reduction='none', delta=huber_f)
        loss_s = torch.nn.HuberLoss(reduction='none', delta=huber_s)

        def train_step(engine, batch):
            """Borrowed from the training step used inside MACE"""
            model.train()
            opt.zero_grad()

            # Run the forward step
            batch_e_pred, batch_f_pred, batch_s_pred = forward_batch(batch, aev_computer, model, ref_energies, pbc, device=device)

            # Compute the losses
            batch_e = batch['energies'].to(device)[:, 0]
            batch_f = batch['forces'].to(device)
            batch_s = batch['stresses'].to(device)
            batch_n = (batch['species'] >= 0).sum(dim=1, dtype=batch_e.dtype).to(device)

            energy_loss = (loss_e(batch_e_pred, batch_e) / batch_n.sqrt()).mean()
            force_loss = (loss_f(batch_f_pred, batch_f).sum(dim=(1, 2)) / batch_n).mean()
            stress_loss = (loss_s(batch_s_pred, batch_s).sum(dim=(1, 2))).mean()
            loss = energy_loss + force_weight * force_loss + stress_weight * stress_loss
            loss.backward()
            opt.step()
            return loss.item()
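
        # The per-batch loss in train_step, with n the number of real atoms in each entry, is
        #   L = mean(Huber(E_pred, E) / sqrt(n))
        #       + force_weight * mean(sum_over_atoms(Huber(F_pred, F)) / n)
        #       + stress_weight * mean(sum_over_components(Huber(S_pred, S)))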

        def evaluate_model(loader: DataLoader,
                           accumulator: list[dict[str, float]]) -> None:
            """Evaluate the model against all data in loader, store results in the accumulator

            Args:
                loader: a pytorch data loader
                accumulator: a list in which to append results
            """
            e_qm, e_ml, f_qm, f_ml, s_qm, s_ml = [], [], [], [], [], []
            for batch in loader:
                # Get the true results
                batch_e = batch['energies'].cpu().numpy()[:, 0]  # Make it a 1D array
                batch_f = batch['forces'].cpu().numpy()
                batch_s = batch['stresses'].cpu().numpy()

                batch_e_pred, batch_f_pred, batch_s_pred = forward_batch(batch, aev_computer, model, ref_energies, pbc, device=device)

                e_qm.extend(batch_e)
                e_ml.extend(batch_e_pred.detach().cpu().numpy())
                f_qm.extend(batch_f.ravel())
                f_ml.extend(batch_f_pred.detach().cpu().numpy().ravel())
                s_qm.extend(batch_s.ravel())
                s_ml.extend(batch_s_pred.detach().cpu().numpy().ravel())

            # convert batch results into flat np arrays
            e_qm, e_ml, f_qm, f_ml = map(lambda a: np.asarray(a), (e_qm, e_ml, f_qm, f_ml))

            # compute and store metrics
            result = {}
            for tag, qm, ml in [('e', e_qm, e_ml), ('f', f_qm, f_ml), ('s', s_qm, s_ml)]:
                diff = np.subtract(qm, ml)
                result[f'{tag}_rmse'] = np.sqrt(np.power(diff, 2).mean())
                result[f'{tag}_mae'] = np.abs(diff).mean()
            accumulator.append(result)
            return

        # Instantiate the trainer
        trainer = Engine(train_step)

        # Set up model evaluators
        # TODO (miketynes): make this more idiomatic pytorch ignite
        # TODO (miketynes): add early stopping (probably depends on the above)
        perf_train: list[dict[str, float]] = []
        perf_val: list[dict[str, float]] = []
        evaluator_train = partial(evaluate_model,
                                  loader=train_loader,
                                  accumulator=perf_train)
        evaluator_val = partial(evaluate_model,
                                loader=valid_loader,
                                accumulator=perf_val)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, evaluator_train)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, evaluator_val)

        # Run the training
        trainer.run(train_loader, max_epochs=num_epochs)

        # Coalesce the training performance
        perf_train, perf_val = map(pd.DataFrame.from_records, [perf_train, perf_val])
        perf_train.rename(columns=lambda x: f'{x}_train', inplace=True)
        perf_val.rename(columns=lambda x: f'{x}_valid', inplace=True)
        perf = pd.concat([perf_train, perf_val], axis=1).reset_index(names='iteration')

        # Ensure GPU memory is cleared
        model.to('cpu')
        aev_computer.to('cpu')

        # Serialize the model
        model_msg = self.serialize_model((aev_computer, model, atomic_energies))
        return model_msg, perf
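
    # Illustrative usage sketch (hypothetical variables); the returned DataFrame holds
    # per-epoch train/validation errors (e.g., 'e_mae_train', 'f_rmse_valid'):
    #   model_msg, log = learner.train(model_msg, train_frames, valid_frames,
    #                                  num_epochs=32, device='cuda')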

    def make_calculator(self, model_msg: bytes | ANIModelContents, device: str) -> Calculator:
        # Unpack the model
        if isinstance(model_msg, bytes):
            model_msg = self.get_model(model_msg)
        aev_computer, model, atomic_energies = model_msg

        # Make a copy of the model because torchANI makes it untrainable
        model.to('cpu')
        model = copy.deepcopy(model)

        # Make an output layer which re-adds the atomic energies and another which converts to Ha
        # (TorchANI's calculator interface works in Hartree)
        ref_energies = torch.tensor(list(atomic_energies.values()), dtype=torch.float32)
        shifter = EnergyShifter(ref_energies)

        class ToHartree(torch.nn.Module):
            def forward(self, species_energies: SpeciesEnergies,
                        cell: torch.Tensor | None = None,
                        pbc: torch.Tensor | None = None) -> SpeciesEnergies:
                species, energies = species_energies
                return SpeciesEnergies(species, energies / units.Hartree)

        to_hartree = ToHartree()
        post_model = Sequential(
            aev_computer,
            model,
            shifter,
            to_hartree
        )  # Use ANI's Sequential, which is fine with `cell` and `pbc` as inputs

        # Assemble the calculator
        species = list(atomic_energies.keys())
        post_model.to(device)
        return ANICalculator(species, post_model, overwrite=False)
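
# Illustrative usage sketch (hypothetical variables; `learner` is a TorchANI instance):
#   calc = learner.make_calculator(model_msg, device='cpu')
#   atoms.calc = calc
#   atoms.get_potential_energy()  # energy in ASE units (eV)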