• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

globus-labs / cascade / 18732618520

22 Oct 2025 11:21PM UTC coverage: 25.34% (-70.0%) from 95.34%
18732618520

Pull #70

github

miketynes
fix init, logging init, waiting
Pull Request #70: Academy proto

261 of 1030 relevant lines covered (25.34%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

25.0
/cascade/learning/chgnet.py
1
"""Training based on the `"ChargeNet" neural network <https://doi.org/10.1038/s42256-023-00716-3>`_"""
2
from contextlib import redirect_stdout
1✔
3
from pathlib import Path
1✔
4
from tempfile import TemporaryDirectory
1✔
5

6
import numpy as np
1✔
7
import ase
1✔
8
import pandas as pd
1✔
9
from ase import units
1✔
10
from ase.calculators.calculator import Calculator
1✔
11
from chgnet.data.dataset import StructureData, get_loader
1✔
12
from chgnet.model import CHGNet, CHGNetCalculator
1✔
13
from chgnet.trainer import Trainer
1✔
14
from pymatgen.io.ase import AseAtomsAdaptor
1✔
15

16
from .base import BaseLearnableForcefield, State
1✔
17

18

19
def make_chgnet_dataset(atoms: list[ase.Atoms]) -> StructureData:
    """Make a dataset ready for use by CHGNet for training

    Args:
        atoms: List of atoms with computed properties. Each entry must carry a
            calculator whose ``results`` hold 'energy', 'forces', and 'stress'.
    Returns:
        Training dataset including forces, energies, and stresses
    Raises:
        ValueError: If an entry has no calculator or lacks a required property
    """

    required = ('energy', 'forces', 'stress')

    structures = []
    energies = []
    forces = []
    stresses = []
    for a in atoms:
        structures.append(AseAtomsAdaptor.get_structure(a))

        # Fail with a clear message instead of an AttributeError on `None.results`
        if a.calc is None:
            raise ValueError('Atoms is missing a calculator with computed properties')

        # Retrieve the properties from the ASE calculator
        for p in required:
            if p not in a.calc.results:
                raise ValueError(f'Atoms is missing property: {p}')

        energies.append(a.get_potential_energy() / len(a))  # CHGNet uses atomic energies
        forces.append(a.get_forces())
        # CHGNet expects training data in kBar and using VASP's sign convention
        stresses.append(a.get_stress(voigt=False) / units.GPa * -10)

    return StructureData(
        structures=structures,
        energies=energies,
        forces=forces,
        stresses=stresses,
    )
50

51

52
class CHGNetInterface(BaseLearnableForcefield[CHGNet]):
    """Interface to training and running CHGNet forcefields"""

    def evaluate(self,
                 model_msg: bytes | CHGNet,
                 atoms: list[ase.Atoms],
                 batch_size: int = 64,
                 device: str = 'cpu') -> tuple[np.ndarray, list[np.ndarray], np.ndarray]:
        """Run the model on a list of structures

        Args:
            model_msg: Model, either as a serialized message or the model object
            atoms: Structures to be evaluated
            batch_size: Number of structures passed to the model per batch
            device: Device on which to run the inference
        Returns:
            - Energy of each structure (multiplied by the atom count if the model is intensive)
            - Forces for each structure, one array per structure
            - Stress of each structure, scaled by ``units.GPa``
              (presumably converting CHGNet's GPa into ASE units -- confirm against CHGNet docs)
        """
        model = self.get_model(model_msg)

        # Convert structures to Pymatgen
        structures = [AseAtomsAdaptor.get_structure(a) for a in atoms]

        # Run everything, returning the model to the CPU even if inference fails
        model.to(device)
        try:
            preds = model.predict_structure(structures, task='efs', batch_size=batch_size)
        finally:
            model.to('cpu')

        # A single prediction may come back as a bare dict rather than a list of dicts
        if isinstance(preds, dict):
            preds = [preds]

        # Transpose into Numpy arrays
        energies = np.array([r['e'] for r in preds])
        if model.is_intensive:
            atom_counts = np.array([len(a) for a in atoms])
            energies *= atom_counts
        forces = [r['f'][:len(a), :] for a, r in zip(atoms, preds)]
        stress = np.array([r['s'] for r in preds]) * units.GPa
        return energies, forces, stress

    def train(self,
              model_msg: bytes | CHGNet,
              train_data: list[ase.Atoms],
              valid_data: list[ase.Atoms],
              num_epochs: int,
              device: str = 'cpu',
              batch_size: int = 32,
              learning_rate: float = 1e-3,
              huber_deltas: tuple[float, float, float] = (0.1, 0.1, 0.1),
              force_weight: float = 1,
              stress_weight: float = 0.1,
              reset_weights: bool = False,
              **kwargs) -> tuple[bytes, pd.DataFrame]:
        """Train the model on energies, forces, and stresses

        Args:
            model_msg: Model, either as a serialized message or the model object
            train_data: Structures used to fit the model
            valid_data: Structures used for validation during training
            num_epochs: Number of training epochs
            device: Device on which to run the training
            batch_size: Number of structures per batch for both data loaders
            learning_rate: Learning rate passed to the CHGNet ``Trainer``
            huber_deltas: Accepted as part of the signature but not used by this implementation
            force_weight: Passed to the trainer as ``force_loss_ratio``
            stress_weight: Passed to the trainer as ``stress_loss_ratio``
            reset_weights: Whether to call ``reset_parameters`` on every submodule
                which provides it before training
            kwargs: Accepted as part of the signature but not used by this implementation
        Returns:
            - Serialized form of the trainer's best model
            - Training log holding each history series which has exactly one entry per epoch
        """
        model = self.get_model(model_msg)

        # Reset weights, if needed
        if reset_weights:
            def init_params(m):
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()

            model.apply(init_params)

        with TemporaryDirectory(prefix='chgnet_') as tmpdir:
            tmpdir = Path(tmpdir)
            # Anything printed during training goes to a file in the temporary directory
            with open(tmpdir / 'chgnet.stdout', 'w') as fp, redirect_stdout(fp):
                # Make the data loaders
                train_dataset = make_chgnet_dataset(train_data)
                valid_dataset = make_chgnet_dataset(valid_data)
                train_loader = get_loader(train_dataset, batch_size=batch_size)
                valid_loader = get_loader(valid_dataset, batch_size=batch_size)

                # Fit the atomic reference energies
                model.composition_model.fit(
                    train_dataset.structures,
                    train_dataset.energies
                )

                # Run the training
                trainer = Trainer(
                    model=model,
                    targets='efs',
                    criterion="Huber",
                    force_loss_ratio=force_weight,
                    stress_loss_ratio=stress_weight,
                    epochs=num_epochs,
                    learning_rate=learning_rate,
                    use_device=device,
                    print_freq=num_epochs + 1  # presumably larger than any step count to silence progress output
                )
                trainer.train(train_loader, valid_loader, train_composition_model=True, save_dir=str(tmpdir))
                model.to('cpu')

                # Store the results
                best_model = trainer.get_best_model()

        # Keep only the series with one value per epoch so the columns share a length
        log = pd.DataFrame({
            f'{key_1}_{key_2}': history
            for key_1, history_1 in trainer.training_history.items()
            for key_2, history in history_1.items()
            if len(history) == num_epochs
        })

        return self.serialize_model(best_model), log

    def make_calculator(self, model_msg: bytes | State, device: str) -> Calculator:
        """Create an ASE calculator backed by the model

        Args:
            model_msg: Model, either as a serialized message or the model object
            device: Device passed to the calculator as ``use_device``
        Returns:
            A ``CHGNetCalculator`` wrapping the model
        """
        model = self.get_model(model_msg)
        return CHGNetCalculator(model=model, use_device=device)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc