• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

globus-labs / cascade / 17808755921

17 Sep 2025 07:39PM UTC coverage: 25.34% (-70.0%) from 95.34%
17808755921

push

github

miketynes
initial agentic implementation

261 of 1030 relevant lines covered (25.34%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.38
/cascade/learning/base.py
1
"""Interface definitions"""
2
from io import BytesIO
1✔
3
from pathlib import Path
1✔
4
from typing import Generic, TypeVar
1✔
5

6
import ase
1✔
7
from ase.calculators.calculator import Calculator
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10
import torch.nn
1✔
11

12
from cascade.calculator import EnsembleCalculator
1✔
13

14
# TODO (wardlt): Break the hard-wire to PyTorch, maybe. I don't have a model yet which uses something else
15
State = TypeVar('State')
1✔
16
"""Generic type for the state of a certain model"""
1✔
17

18

19
class BaseLearnableForcefield(Generic[State]):
1✔
20
    """Interface for learning and evaluating a forcefield
21

22
    Using a Learnable Forcefield
23
    ----------------------------
24

25

26
    The learnable forcefield class defines a reduced interface to a surrogate model
27
    that computes the energies and forces of a system of atoms.
28
    The interfaces are designed to be simple to facilitate integration within a workflow
29
    and operate on serializable Python types to allow the workflow to run across distributed nodes.
30

31
    The key functions for use in workflows are
32

33
    - :meth:`evaluate` to predict the energies and forces of a series of atomic structures
34
    - :meth:`train` to update the machine learning model given a new set of structures
35
    - :meth:`make_calculator` to produce an ASE Calculator suitable for use in running dynamics
36

37
    The first argument to each of these functions is a "State" that is the components of a machine learning model
38
    in their original or serialized form.
39
    Each implementation varies in what defines the "State" of a model, but all share
40
    the :meth:`serialize_model` function to produce a byte-string version of the State
41
    before sending it to a remote compute node.
42

43
    Implementing a Learnable Forcefield
44
    -----------------------------------
45

46
    Implementations must define the :meth:`evaluate` and :meth:`train` functions,
47
    which provide an imperfect but sufficient interface for training the model.
48

49
    The functions must take either the serialized version of the model as a byte string
50
    or the unserialized version for workflows run on a single node.
51
    Use the :meth:`get_model` function to deserialize a byte string.
52

53
    Declare the type used to represent the state of your model as a Python type specification
54
    that is passed as a generic argument to the class.
55
    For example,
56

57
    .. code:: python
58

59
        State = list[float]
60

61
        class MyClass(BaseLearnableForcefield[State]):
62

63
            def evaluate(...
64

65
    Add any arguments to the :meth:`train` function as appropriate for updating the weights
66
    of a machine learning model, but do not add any for creating a new architecture.
67
    Cascade workflows are designed to use a fixed architecture rather than
68
    perform hyperparameter optimization.
69
    Create utility functions for defining the architecture in a separate module.
70
    """
71

72
    def __init__(self, scratch_dir: Path | None = None):
1✔
73
        """
74

75
        Args:
76
            scratch_dir: Path used to store temporary files
77
        """
78
        self.scratch_dir = scratch_dir
×
79

80
    def serialize_model(self, state: State | bytes) -> bytes:
1✔
81
        """Serialize the state of a model into a byte string
82

83
        Args:
84
            state: Model state
85
        Returns:
86
            Form ready for transmission to a compute node
87
        """
88
        if not isinstance(state, bytes):
×
89
            b = BytesIO()
×
90
            torch.save(state, b)
×
91
            return b.getvalue()
×
92
        return state
×
93

94
    def get_model(self, model_msg: bytes | State) -> State:
1✔
95
        """Load a model from the provided message and place it in CPU memory
96

97
        Args:
98
            model_msg: Model message
99
        Returns:
100
            The model ready for use in a function
101
        """
102
        if isinstance(model_msg, bytes):
×
103
            return torch.load(BytesIO(model_msg), map_location='cpu')
×
104
        return model_msg
×
105

106
    def evaluate(self,
1✔
107
                 model_msg: bytes | State,
108
                 atoms: list[ase.Atoms],
109
                 batch_size: int = 64,
110
                 device: str = 'cpu') -> (np.ndarray, list[np.ndarray], np.ndarray):
111
        """Run inference for a series of structures
112

113
        Args:
114
            model_msg: Model to evaluate
115
            atoms: List of structures to evaluate
116
            batch_size: Number of molecules to evaluate per batch
117
            device: Device on which to run the computation
118
        Returns:
119
            - Energy for each structure. (N,) array of floats, where N is the number of structures
120
            - Forces for each structure. List of N arrays of (n, 3), where n is the number of atoms in each structure
121
            - Stresses for each structure. (N, 3, 3) array, where each row is a stress tensor.
122
        """
123
        raise NotImplementedError()
×
124

125
    def train(self,
1✔
126
              model_msg: bytes | State,
127
              train_data: list[ase.Atoms],
128
              valid_data: list[ase.Atoms],
129
              num_epochs: int,
130
              device: str = 'cpu',
131
              batch_size: int = 32,
132
              learning_rate: float = 1e-3,
133
              huber_deltas: tuple[float, float, float] = (0.5, 1, 1),
134
              force_weight: float = 10,
135
              stress_weight: float = 100,
136
              reset_weights: bool = False,
137
              **kwargs) -> tuple[bytes, pd.DataFrame]:
138
        """Train a model
139

140
        Args:
141
            model_msg: Model to be retrained
142
            train_data: Structures used for training
143
            valid_data: Structures used for validation
144
            num_epochs: Number of training epochs
145
            device: Device (e.g., 'cuda', 'cpu') used for training
146
            batch_size: Batch size during training
147
            learning_rate: Initial learning rate for optimizer
148
            huber_deltas: Delta parameters for the Huber loss functions, presumably for energy, force, and stress in that order
149
            force_weight: Amount of weight to use for the force part of the loss function
150
            stress_weight: Amount of weight to use for the stress part of the loss function
151
            reset_weights: Whether to reset the weights before training
152
        Returns:
153
            - model: Retrained model
154
            - history: Training history
155
        """
156
        raise NotImplementedError()
×
157

158
    def make_calculator(self, model_msg: bytes | State, device: str) -> Calculator:
1✔
159
        """Make an ASE calculator form of the provided model
160

161
        Args:
162
            model_msg: Model in its serialized or original form
163
            device: Device on which to run computations
164
        Returns:
165
            Model turned into a calculator
166
        """
167
        raise NotImplementedError()
×
168

169
    def make_ensemble_calculator(self, model_msgs: list[bytes | State], device: str) -> EnsembleCalculator:
1✔
170
        raise NotImplementedError()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc